Skip to content
Snippets Groups Projects
Commit b0e611e5 authored by Pat Alt's avatar Pat Alt
Browse files

sorting out dim reduction

parent 36115c9a
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -19,6 +19,7 @@ MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PkgTemplates = "14b8a8f1-9102-5b29-a752-f990bacb7fe1"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
......
......@@ -417,9 +417,11 @@ version = "0.6.3"
[[deps.CounterfactualExplanations]]
deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"]
path = "../../CounterfactualExplanations.jl"
git-tree-sha1 = "14da4a8ea118b96c2477b05d5bc1c353c1d80e79"
repo-rev = "main"
repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git"
uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
version = "0.1.26"
version = "0.1.28"
[deps.CounterfactualExplanations.extensions]
MPIExt = "MPI"
......@@ -554,7 +556,7 @@ uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
version = "0.6.8"
[[deps.ECCCo]]
deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "Images", "JointEnergyModels", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "PkgTemplates", "Random", "Statistics", "StatsBase", "Term", "cuDNN"]
deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "Images", "JointEnergyModels", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "MultivariateStats", "Parameters", "PkgTemplates", "Random", "Statistics", "StatsBase", "Term", "cuDNN"]
path = ".."
uuid = "0232c203-4013-4b0d-ad96-43e3e11ac3bf"
version = "0.1.0"
......
......@@ -8,6 +8,7 @@ function default_generators(;
nsamples::Union{Nothing,Int}=nothing,
nmin::Union{Nothing,Int}=nothing,
reg_strength::Real=0.5,
dim_reduction::Union{Nothing,Int}=nothing,
)
@info "Begin benchmarking counterfactual explanations."
......@@ -35,6 +36,19 @@ function default_generators(;
"ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength),
)
end
# Dimensionality reduction:
# If dimensionality reduction is specified, add ECCCo-Δ (latent) to the generator dictionary:
if dim_reduction
eccco_latent = Dict(
"ECCCo-Δ (latent)" => ECCCoGenerator(
λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength,
dim_reduction=dim_reduction
)
)
generator_dict = merge(generator_dict, eccco_latent)
end
return generator_dict
end
......@@ -66,6 +80,7 @@ function run_benchmark(exper::Experiment, model_dict::Dict)
nsamples=exper.nsamples,
nmin=exper.nmin,
reg_strength=exper.reg_strength,
dim_reduction=exper.dim_reduction,
)
end
......
......@@ -42,6 +42,7 @@ Base.@kwdef struct Experiment
model_tuning_params::NamedTuple = DEFAULT_MODEL_TUNING_SMALL
use_tuned::Bool = true
store_ce::Bool = STORE_CE
dim_reduction::Bool = false
end
"A container to hold the results of an experiment."
......
......@@ -35,7 +35,7 @@ ce_measures = [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_f
# Parameter choices:
params = (
n_individuals=N_IND_SPECIFIED ? N_IND : 100,
n_individuals=N_IND_SPECIFIED ? N_IND : 2,
builder=default_builder(n_hidden=128, n_layers=1, activation=Flux.swish),
𝒟x=Uniform(-1.0, 1.0),
α=[1.0, 1.0, 1e-2],
......@@ -48,11 +48,12 @@ params = (
nsamples=10,
nmin=1,
niter_eccco=10,
Λ=[0.005, 0.25, 0.25],
Λ_Δ=[0.005, 0.1, 0.5],
Λ=[0.01, 0.25, 0.25],
Λ_Δ=[0.01, 0.1, 0.3],
opt=Flux.Optimise.Descent(0.1),
reg_strength = 0.0,
ce_measures=ce_measures,
dim_reduction=true,
)
if !GRID_SEARCH
......
......@@ -13,6 +13,7 @@ function ECCCoGenerator(;
nmin::Union{Nothing,Int}=nothing,
niter::Union{Nothing,Int}=nothing,
reg_strength::Real=0.1,
dim_reduction::Bool=false,
kwargs...
)
......@@ -44,5 +45,5 @@ function ECCCoGenerator(;
λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ
# Generator
return GradientBasedGenerator(; loss=loss_fun, penalty=_penalties, λ=λ, opt=opt, kwargs...)
return GradientBasedGenerator(; loss=loss_fun, penalty=_penalties, λ=λ, opt=opt, dim_reduction=dim_reduction, kwargs...)
end
\ No newline at end of file
......@@ -4,6 +4,7 @@ using Distances
using Flux
using Images: assess_ssim
using LinearAlgebra: norm
using MultivariateStats
using Statistics: mean
"""
......@@ -74,10 +75,10 @@ Computes the distance from the counterfactual to generated conditional samples.
"""
function distance_from_energy(
ce::AbstractCounterfactualExplanation;
n::Int=10, niter=500, from_buffer=true, agg=mean,
n::Int=50, niter=500, from_buffer=true, agg=mean,
choose_lowest_energy=true,
choose_random=false,
nmin::Int=10,
nmin::Int=25,
return_conditionals=false,
p::Int=1,
kwargs...
......@@ -127,10 +128,10 @@ Computes the cosine distance from the counterfactual to generated conditional sa
"""
function distance_from_energy_cosine(
ce::AbstractCounterfactualExplanation;
n::Int=10, niter=500, from_buffer=true, agg=mean,
n::Int=50, niter=500, from_buffer=true, agg=mean,
choose_lowest_energy=true,
choose_random=false,
nmin::Int=10,
nmin::Int=25,
return_conditionals=false,
kwargs...
)
......@@ -177,10 +178,10 @@ Computes 1-SSIM from the counterfactual to generated conditional samples where S
"""
function distance_from_energy_ssim(
ce::AbstractCounterfactualExplanation;
n::Int=10, niter=500, from_buffer=true, agg=mean,
n::Int=50, niter=500, from_buffer=true, agg=mean,
choose_lowest_energy=true,
choose_random=false,
nmin::Int=10,
nmin::Int=25,
return_conditionals=false,
kwargs...
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment