diff --git a/artifacts/results/images/mnist_generated_JEM Ensemble.png b/artifacts/results/images/mnist_generated_JEM Ensemble.png new file mode 100644 index 0000000000000000000000000000000000000000..015cddd10299c4734b5faffcce4585b0582150a4 Binary files /dev/null and b/artifacts/results/images/mnist_generated_JEM Ensemble.png differ diff --git a/artifacts/results/images/mnist_generated_JEM.png b/artifacts/results/images/mnist_generated_JEM.png new file mode 100644 index 0000000000000000000000000000000000000000..6ff8e5780dae3cd4ac3c775c7faf31801ccc6f5a Binary files /dev/null and b/artifacts/results/images/mnist_generated_JEM.png differ diff --git a/artifacts/results/images/mnist_generated_MLP Ensemble.png b/artifacts/results/images/mnist_generated_MLP Ensemble.png new file mode 100644 index 0000000000000000000000000000000000000000..4834dde4916fd890c7b13523148ffee9405ef0c7 Binary files /dev/null and b/artifacts/results/images/mnist_generated_MLP Ensemble.png differ diff --git a/artifacts/results/images/mnist_generated_MLP.png b/artifacts/results/images/mnist_generated_MLP.png new file mode 100644 index 0000000000000000000000000000000000000000..02e86da900666b6c564655448a764fc347228264 Binary files /dev/null and b/artifacts/results/images/mnist_generated_MLP.png differ diff --git a/artifacts/results/mnist_model_performance.csv b/artifacts/results/mnist_model_performance.csv new file mode 100644 index 0000000000000000000000000000000000000000..19ef85f908ff9b49191bb26f1d7f5b9286991176 --- /dev/null +++ b/artifacts/results/mnist_model_performance.csv @@ -0,0 +1,5 @@ +acc,precision,f1score,mod_name +0.91,0.9103128613256043,0.9085266898485427,JEM Ensemble +0.9423,0.9418035808913033,0.9415726345973308,MLP +0.9439,0.943391966967219,0.9432116151207373,MLP Ensemble +0.8748,0.8798390712341672,0.8728379219312089,JEM diff --git a/artifacts/results/mnist_model_performance.jls b/artifacts/results/mnist_model_performance.jls new file mode 100644 index 0000000000000000000000000000000000000000..9e6a885e0eeb8b869672d63905a5644cbde12436 Binary files /dev/null and b/artifacts/results/mnist_model_performance.jls differ diff --git a/artifacts/results/mnist_models.jls b/artifacts/results/mnist_models.jls index 75cab5e0485a29b005e2053eb200c9dbb55e468f..df1acd84992a396709245a2a93d369f665f85622 100644 Binary files a/artifacts/results/mnist_models.jls and b/artifacts/results/mnist_models.jls differ diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index 0dbb47ae5322013c4017e5495b4b7afc1c8d5ad5..6167b3c118ef4be2c37a9a9a4f84af34b3c5d260 100644 Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls index e569cb9fc2f85467a252315b1924fbf94eb7fd96..b81b474a74b5ffd83b8650a49727326eb905b732 100644 Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 696cd7861a6a3d5bfc4dd8f2a0d822b24dbd3f7a..7cf36c56db015edbd0be538cb7ee9e6f18922ba4 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -138,7 +138,7 @@ end ```{julia} # Data: -n_obs = 1000 +n_obs = 10000 counterfactual_data = load_mnist(n_obs) counterfactual_data.X = pre_process.(counterfactual_data.X) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) @@ -242,13 +242,11 @@ Serialization.serialize(joinpath(output_path,"mnist_models.jls"), model_dict) ``` ```{julia} -for (mod_name, mod) in model_dict +# Plot generated samples: - # Plot: - if mod.model.model isa JointEnergyClassifier - sampler = mod.model.model.jem.sampler - elseif mod.model.model.model isa JointEnergyClassifier - sampler = mod.model.model.model.jem.sampler +for (mod_name, mod) in model_dict + if ECCCo._has_sampler(mod) + sampler = ECCCo._get_sampler(mod) else K = length(counterfactual_data.y_levels) input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1)) @@ -275,14 +273,90 @@ for (mod_name, mod) in model_dict plts = [plts..., plt] end plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name) + savefig(plt, joinpath(output_images_path, "mnist_generated_$(mod_name).png")) display(plt) +end +``` +```{julia} +# Evaluate models: + +measure = Dict( + :f1score => multiclass_f1score, + :acc => accuracy, + :precision => multiclass_precision +) +model_performance = DataFrame() +for (mod_name, mod) in model_dict # Test performance: test_data = load_mnist_test() test_data.X = pre_process.(test_data.X) - f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data) - println("F1 score (test): $(round(f1,digits=3))") + _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure))) + _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) + _perf.mod_name .= mod_name + model_performance = vcat(model_performance, _perf) end +Serialization.serialize(joinpath(output_path,"mnist_model_performance.jls"), model_performance) +CSV.write(joinpath(output_path, "mnist_model_performance.csv"), model_performance) +model_performance +``` + +```{julia} +Random.seed!(123) + +# Set up search: +factual = 8 +x = reshape(counterfactual_data.X[:,rand(findall(labels.==factual))],input_dim,1) +target = 3 +γ = 0.9 +T = 100 + +# ECCCo: +λ=[0.5,0.1,0.5] +temp=0.1 +η=0.01 + +# Generate counterfactuals using ECCCo generator: +generator = ECCCoGenerator( + λ=λ, + temp=temp, + opt=Flux.Optimise.Adam(η), +) + +ces = Dict() +for (mod_name, mod) in model_dict + ce = generate_counterfactual( + x, target, counterfactual_data, mod, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, + converge_when=:generator_conditions, + ) + ces[mod_name] = ce +end + +# Plot: +p1 = Plots.plot( + convert2image(MNIST, reshape(x,28,28)), + axis=nothing, + size=(img_height, img_height), + title="Factual" +) +plts = [p1] + +for (_name,ce) in ces + x = CounterfactualExplanations.counterfactual(ce) + _phat = target_probs(ce) + _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" + plt = Plots.plot( + convert2image(MNIST, reshape(x,28,28)), + axis=nothing, + size=(img_height, img_height), + title=_title + ) + plts = [plts..., plt] +end +plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) +display(plt) ``` ```{julia} diff --git a/notebooks/setup.jl b/notebooks/setup.jl index a377a5880eef6b65f10082c6a59916a23a371f88..8d711c1f4d85d3278265a48f477b98b1c3a35dce 100644 --- a/notebooks/setup.jl +++ b/notebooks/setup.jl @@ -39,6 +39,7 @@ setup_notebooks = quote Random.seed!(2023) www_path = "www" output_path = "artifacts/results" + output_images_path = "artifacts/results/images" img_height = 300 end; \ No newline at end of file diff --git a/paper/paper.pdf b/paper/paper.pdf index fdebedc3bdbb524bbfe1d6be1cf98a98524f4baa..6948356686c7558b15c1d17fd67d8c7613b6f4da 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/paper/paper.tex b/paper/paper.tex index 3984524dae468dc4ce963e696442228850e55213..fb8ddcdfc1de94d649676f5f0e78094419d70761 100644 --- a/paper/paper.tex +++ b/paper/paper.tex @@ -266,7 +266,8 @@ As noted by \citet{guidotti2022counterfactual}, these distance-based measures ar \item BatchNorm does not seem compatible with JEM \item Coverage and temperature impacts CCE in somewhat unpredictable ways \item It seems that models that are not explicitly trained for generative task, still learn it implictly - \item Batch size seems to impact quality of generated samples + \item Batch size seems to impact quality of generated samples (at inference, but not so much during JEM training) + \item ECCCo is sensitive to optimizer (Adam works well), learning rate and distance metric (l1 works well) \end{itemize} \section{Discussion} diff --git a/src/generator.jl b/src/generator.jl index 72952aa31de89bf7d1a8d297a8337abe97cf9a16..87b1dceff3e45e92075efdfb5a69f614eb15e2af 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -28,7 +28,7 @@ function ECCCoGenerator(; function _set_size_penalty(ce::AbstractCounterfactualExplanation) return ECCCo.set_size_penalty(ce; κ=κ, temp=temp) end - _penalties = [Objectives.distance_l2, _set_size_penalty, ECCCo.distance_from_energy] + _penalties = [Objectives.distance_l1, _set_size_penalty, ECCCo.distance_from_energy] λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ return Generator(; penalty=_penalties, λ=λ, opt=opt, kwargs...) end diff --git a/src/model.jl b/src/model.jl index 38188816149952099159cc2b29334622c958b727..f2540df96f23f970da3827860b7187fb691b8839 100644 --- a/src/model.jl +++ b/src/model.jl @@ -60,6 +60,33 @@ function _outdim(fitresult) return outdim end +""" + _get_sampler(model::ConformalModel) + +Private helper function that extracts the sampler from a fitted model. +""" +function _get_sampler(model::ConformalModel) + _mod = model.model + if _mod.model isa MLJEnsembles.EitherEnsembleModel + _mod = _mod.model + end + if _mod.model isa JointEnergyClassifier + sampler = _mod.model.sampler + else + sampler = false + end + return sampler +end + +""" + _has_sampler(model::ConformalModel) + +Private helper function that checks if a fitted model has a sampler. +""" +function _has_sampler(model::ConformalModel) + return !(_get_sampler(model) isa Bool) +end + """ ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symbol}=nothing) diff --git a/src/penalties.jl b/src/penalties.jl index ccd1c8a911e1136f82422b8ee4ca200b58b47b8b..d3fdf8e15e6ba8703e5e7dc2cd8583f4a278713e 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -51,7 +51,6 @@ function distance_from_energy( x′ = CounterfactualExplanations.counterfactual(ce) loss = map(eachslice(x′, dims=ndims(x′))) do x Δ = map(eachcol(conditional_samples[1])) do xsample - # 1 .- (x'xsample)/(norm(x)*norm(xsample)) norm(x - xsample, 1) end return mean(Δ) diff --git a/src/sampling.jl b/src/sampling.jl index 3cd4b60470fc3d42e37300ddfe5bbcf903611c3b..2908f1d63a49f7bc5eebb90c839893bf5c6f5c8c 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -43,8 +43,8 @@ function EnergySampler( @assert y ∈ data.y_levels || y ∈ 1:length(data.y_levels) - if model.model.model isa JointEnergyClassifier - sampler = model.model.model.jem.sampler + if ECCCo._has_sampler(model) + sampler = ECCCo._get_sampler(model) else K = length(data.y_levels) input_size = size(selectdim(data.X, ndims(data.X), 1)) @@ -57,8 +57,10 @@ function EnergySampler( # Initiate: energy_sampler = EnergySampler(model, data, sampler, opt, nothing, nothing) - # Generate conditional samples: - generate_samples!(energy_sampler, nsamples, yidx; niter=niter) + # Generate conditional samples (one at a time): + for i in 1:nsamples + generate_samples!(energy_sampler, 1, yidx; niter=niter) + end return energy_sampler end @@ -106,7 +108,11 @@ end Generates `n` samples from `EnergySampler` for conditioning value `y`. Assigns samples and conditioning value to `EnergySampler`. """ function generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100) - e.buffer = generate_samples(e, n, y; niter=niter) + if isnothing(e.buffer) + e.buffer = generate_samples(e, n, y; niter=niter) + else + e.buffer = cat(e.buffer, generate_samples(e, n, y; niter=niter), dims=ndims(e.buffer)) + end e.yidx = y end