diff --git a/Project.toml b/Project.toml index c57589f5a50edb2b1bd7ca8a2d4e97e95a0c0b4b..cc58ad66c61be5b5b9ac1e8f8291ed03a889e0e7 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" PkgTemplates = "14b8a8f1-9102-5b29-a752-f990bacb7fe1" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml index 1036b7b143def486f53bf759025b771a7814311f..db3fc404b5bee0f29dbc4173f7639a9aabb5d337 100644 --- a/experiments/Manifest.toml +++ b/experiments/Manifest.toml @@ -417,9 +417,11 @@ version = "0.6.3" [[deps.CounterfactualExplanations]] deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"] -path = "../../CounterfactualExplanations.jl" +git-tree-sha1 = "14da4a8ea118b96c2477b05d5bc1c353c1d80e79" +repo-rev = "main" +repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.26" +version = "0.1.28" [deps.CounterfactualExplanations.extensions] MPIExt = "MPI" @@ -554,7 +556,7 @@ uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" version = "0.6.8" [[deps.ECCCo]] -deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "Images", "JointEnergyModels", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "PkgTemplates", "Random", "Statistics", "StatsBase", "Term", "cuDNN"] +deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "Images", "JointEnergyModels", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "MultivariateStats", "Parameters", "PkgTemplates", "Random", "Statistics", "StatsBase", "Term", "cuDNN"] path = ".." uuid = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" version = "0.1.0" diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl index 6e5ca946f17c62cb35e37be9e9fd7942d1f7b926..91dca4127a8ef02e12b7b386732039d28e913932 100644 --- a/experiments/benchmarking/benchmarking.jl +++ b/experiments/benchmarking/benchmarking.jl @@ -8,6 +8,7 @@ function default_generators(; nsamples::Union{Nothing,Int}=nothing, nmin::Union{Nothing,Int}=nothing, reg_strength::Real=0.5, + dim_reduction::Union{Nothing,Int}=nothing, ) @info "Begin benchmarking counterfactual explanations." @@ -35,6 +36,19 @@ function default_generators(; "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength), ) end + + # Dimensionality reduction: + # If dimensionality reduction is specified, add ECCCo-Δ (latent) to the generator dictionary: + if dim_reduction + eccco_latent = Dict( + "ECCCo-Δ (latent)" => ECCCoGenerator( + λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength, + dim_reduction=dim_reduction + ) + ) + generator_dict = merge(generator_dict, eccco_latent) + end + return generator_dict end @@ -66,6 +80,7 @@ function run_benchmark(exper::Experiment, model_dict::Dict) nsamples=exper.nsamples, nmin=exper.nmin, reg_strength=exper.reg_strength, + dim_reduction=exper.dim_reduction, ) end diff --git a/experiments/experiment.jl b/experiments/experiment.jl index 6fb0e711d8128c9f02f65d62667d9c4b40cacab5..34202214127fa028f2d13cdedf1f2fe692f572a0 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -42,6 +42,7 @@ Base.@kwdef struct Experiment model_tuning_params::NamedTuple = DEFAULT_MODEL_TUNING_SMALL use_tuned::Bool = true store_ce::Bool = STORE_CE + dim_reduction::Bool = false end "A container to hold the results of an experiment." diff --git a/experiments/mnist.jl b/experiments/mnist.jl index 7ee81a68e81b4ce15d40cc8b5eb8187fc0d7b6ca..14cc424470bf9b3ac774402a35e91dc374e335eb 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -35,7 +35,7 @@ ce_measures = [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_f # Parameter choices: params = ( - n_individuals=N_IND_SPECIFIED ? N_IND : 100, + n_individuals=N_IND_SPECIFIED ? N_IND : 2, builder=default_builder(n_hidden=128, n_layers=1, activation=Flux.swish), ð’Ÿx=Uniform(-1.0, 1.0), α=[1.0, 1.0, 1e-2], @@ -48,11 +48,12 @@ params = ( nsamples=10, nmin=1, niter_eccco=10, - Λ=[0.005, 0.25, 0.25], - Λ_Δ=[0.005, 0.1, 0.5], + Λ=[0.01, 0.25, 0.25], + Λ_Δ=[0.01, 0.1, 0.3], opt=Flux.Optimise.Descent(0.1), reg_strength = 0.0, ce_measures=ce_measures, + dim_reduction=true, ) if !GRID_SEARCH diff --git a/src/generator.jl b/src/generator.jl index 725cf22771f95d624181bc7574f93eff28747c8b..5974f51355c41978ff43710cf05dc448814d4305 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -13,6 +13,7 @@ function ECCCoGenerator(; nmin::Union{Nothing,Int}=nothing, niter::Union{Nothing,Int}=nothing, reg_strength::Real=0.1, + dim_reduction::Bool=false, kwargs... ) @@ -44,5 +45,5 @@ function ECCCoGenerator(; λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ # Generator - return GradientBasedGenerator(; loss=loss_fun, penalty=_penalties, λ=λ, opt=opt, kwargs...) + return GradientBasedGenerator(; loss=loss_fun, penalty=_penalties, λ=λ, opt=opt, dim_reduction=dim_reduction, kwargs...) end \ No newline at end of file diff --git a/src/penalties.jl b/src/penalties.jl index 423d955cd7af2ec6eb07a8105aa85a87b8f31ec7..c4272dd5ba0827a8ace1e42686c3dd9b0b4e24a0 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -4,6 +4,7 @@ using Distances using Flux using Images: assess_ssim using LinearAlgebra: norm +using MultivariateStats using Statistics: mean """ @@ -74,10 +75,10 @@ Computes the distance from the counterfactual to generated conditional samples. """ function distance_from_energy( ce::AbstractCounterfactualExplanation; - n::Int=10, niter=500, from_buffer=true, agg=mean, + n::Int=50, niter=500, from_buffer=true, agg=mean, choose_lowest_energy=true, choose_random=false, - nmin::Int=10, + nmin::Int=25, return_conditionals=false, p::Int=1, kwargs... @@ -127,10 +128,10 @@ Computes the cosine distance from the counterfactual to generated conditional sa """ function distance_from_energy_cosine( ce::AbstractCounterfactualExplanation; - n::Int=10, niter=500, from_buffer=true, agg=mean, + n::Int=50, niter=500, from_buffer=true, agg=mean, choose_lowest_energy=true, choose_random=false, - nmin::Int=10, + nmin::Int=25, return_conditionals=false, kwargs... ) @@ -177,10 +178,10 @@ Computes 1-SSIM from the counterfactual to generated conditional samples where S """ function distance_from_energy_ssim( ce::AbstractCounterfactualExplanation; - n::Int=10, niter=500, from_buffer=true, agg=mean, + n::Int=50, niter=500, from_buffer=true, agg=mean, choose_lowest_energy=true, choose_random=false, - nmin::Int=10, + nmin::Int=25, return_conditionals=false, kwargs... )