diff --git a/artifacts/results/mnist_architectures.jls b/artifacts/results/mnist_architectures.jls index 4054e47c244cf42615793282663b50e07abf4755..c227bf2cec8ce4c773d4a25f0e73440f03486b4c 100644 Binary files a/artifacts/results/mnist_architectures.jls and b/artifacts/results/mnist_architectures.jls differ diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index 1a74f76e4958d83d3fc08c1fe1c247292e07f297..3537f90018d2104bc168f3b6d848f7030487fcaa 100644 Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls index 2100c22bdb104119ce7eb024a9d24da877da7105..96a8d445d9ec1ad815bf56539ae121d82f7e742c 100644 Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ diff --git a/notebooks/.CondaPkg/env/conda-meta/history b/notebooks/.CondaPkg/env/conda-meta/history index dd869616bef06f2eeffd268caa0a34b25e3af339..6c7683b9e57e48bee3f8f60a08c00c9a8f1c2d8b 100644 --- a/notebooks/.CondaPkg/env/conda-meta/history +++ b/notebooks/.CondaPkg/env/conda-meta/history @@ -1,4 +1,4 @@ -==> 2023-08-08 15:03:23 <== +==> 2023-08-10 11:22:12 <== # cmd: /Users/FA31DU/.julia/artifacts/6ecf04294c7f327e02e84972f34835649a5eb35e/bin/micromamba -r /Users/FA31DU/.julia/scratchspaces/0b3b1443-0f03-428d-bdfb-f27f9c1191ea/root create -y -p /Users/FA31DU/code/ECCCo.jl/notebooks/.CondaPkg/env --override-channels --no-channel-priority numpy[version='*'] pip[version='>=22.0.0'] python[version='>=3.7,<4',channel='conda-forge',build='*cpython*'] -c conda-forge # conda version: 3.8.0 +https://conda.anaconda.org/conda-forge/osx-64::xz-5.2.6-h775f41a_0 diff --git a/notebooks/.CondaPkg/meta b/notebooks/.CondaPkg/meta index 1b11b47e6fbbf80a0ee54e97be26195177f40138..6b7ba5c59ae065f773611edc2fc41659836fb772 100644 Binary files a/notebooks/.CondaPkg/meta and b/notebooks/.CondaPkg/meta differ diff --git a/notebooks/mnist.qmd 
"""
    plot_mnist(ce; size=(img_height, img_height), kwargs...)

Plot the counterfactual held by the counterfactual explanation `ce` as an MNIST image.

The counterfactual state is obtained via `CounterfactualExplanations.counterfactual(ce)`,
reshaped to 28×28 and converted with `convert2image(MNIST, ...)`.

# Keywords
- `size`: plot size; the default reads the notebook-level variable `img_height`, which
  therefore has to be defined before this method is called without an explicit `size`.
- `kwargs...`: forwarded unchanged to `Plots.plot`.

# Returns
The plot object created by `Plots.plot`.
"""
function plot_mnist(ce; size=(img_height, img_height), kwargs...)
    x = CounterfactualExplanations.counterfactual(ce)
    img = convert2image(MNIST, reshape(x, 28, 28))
    # Ticks and axis lines are switched off so that only the digit is shown.
    return Plots.plot(img; axis=([], false), size=size, kwargs...)
end
# NOTE(review): hard-coded to `true`, so the full 10×10 grid is regenerated on every
# run of this cell — confirm this is intended.
_regen_all_digits = true
if _regen_all_digits
    # Build a 10×10 grid of counterfactual plots, one for every (factual digit i,
    # target digit j) pair with i, j ∈ 0:9.
    #
    # - `rng`: forwarded to `plot_mnist` as the `rng` keyword.
    # - `verbose`: when `true`, each individual plot is displayed as soon as it is made.
    # - `img_height`: forwarded to `plot_mnist`; also fixes the size of the final grid
    #   (`img_height * 10` in each direction).
    # - `kwargs...`: forwarded to `plot_mnist`; `rng` and `img_height` are passed after
    #   them and therefore take precedence over same-named entries in `kwargs`.
    #
    # Returns the combined plot.
    function plot_all_digits(rng=123; verbose=true, img_height=180, kwargs...)
        plts = []
        # `i` is the outer and `j` the inner index, so plots are collected row by row.
        for i in 0:9, j in 0:9
            @info "Generating counterfactual for $(i) -> $(j)"
            plt = plot_mnist(i, j; kwargs..., rng=rng, img_height=img_height)
            verbose && display(plt)
            # Grow in place instead of rebuilding the whole vector on every iteration.
            push!(plts, plt)
        end
        return Plots.plot(
            plts...; size=(img_height * 10, img_height * 10), layout=(10, 10), dpi=300
        )
    end
    plt = plot_all_digits(generator=gen_delta)
    savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits-delta.png"))
end
=> 1.)); + test_size=test_size +) +X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) +X = table(permutedims(X)) +labels = counterfactual_data.output_encoder.labels +input_dim, n_obs = size(counterfactual_data.X) +output_dim = length(unique(labels)) +``` + +First, let's create a couple of image classifier architectures: + +```{julia} +# Model parameters: +epochs = 100 +batch_size = minimum([Int(round(n_obs/10)), 128]) +n_hidden = 16 +activation = Flux.swish +builder = MLJFlux.MLP( + hidden=(n_hidden, n_hidden, n_hidden), + σ=Flux.swish +) +n_ens = 5 # number of models in ensemble +_loss = Flux.Losses.crossentropy # loss function +_finaliser = Flux.softmax # finaliser function +``` + +```{julia} +# JEM parameters: +ð’Ÿx = Normal() +ð’Ÿy = Categorical(ones(output_dim) ./ output_dim) +sampler = ConditionalSampler( + ð’Ÿx, ð’Ÿy, + input_size=(input_dim,), + batch_size=50, +) +α = [1.0,1.0,1e-1] # penalty strengths +``` + + +```{julia} +# Joint Energy Model: +model = JointEnergyClassifier( + sampler; + builder=builder, + epochs=epochs, + batch_size=batch_size, + finaliser=_finaliser, + loss=_loss, + jem_training_params=( + α=α,verbosity=10, + ), + sampling_steps=30, +) +``` + +```{julia} +conf_model = conformal_model(model; method=:simple_inductive, coverage=0.95) +mach = machine(conf_model, X, labels) +@info "Begin training model." +fit!(mach) +@info "Finished training model." +M = ECCCo.ConformalModel(mach.model, mach.fitresult) +``` + +```{julia} +λ₠= 0.25 +λ₂ = 0.75 +λ₃ = 0.75 +Λ = [λâ‚, λ₂, λ₃] + +opt = Flux.Optimise.Descent(0.01) +use_class_loss = false + +# Benchmark generators: +generator_dict = Dict( + "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss), + "ECCCo (energy delta)" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true), +) +``` + +```{julia} +Random.seed!(2023) + +X = X isa Matrix ? 
X : Float32.(permutedims(matrix(X))) +factual_label = levels(labels)[1] +x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) +target = levels(labels)[2] +factual = predict_label(M, counterfactual_data, x_factual)[1] + +ces = Dict{Any,Any}() +plts = [] +for (name, generator) in generator_dict + ce = generate_counterfactual( + x_factual, target, counterfactual_data, M, generator; + initialization=:identity, + converge_when=:generator_conditions, + ) + plt = Plots.plot( + ce, title=name, alpha=0.2, + cbar=false, + ) + if contains(name, "ECCCo") + _X = distance_from_energy(ce, return_conditionals=true) + Plots.scatter!( + _X[1,:],_X[2,:], color=:purple, shape=:star5, + ms=10, label="xÌ‚|$target", alpha=0.5 + ) + end + push!(plts, plt) + ces[name] = ce +end +plt = Plots.plot(plts..., size=(800,350)) +display(plt) +``` \ No newline at end of file diff --git a/src/generator.jl b/src/generator.jl index c356fb1aa9c956e30609e943cb99074b785e5136..2119b54407a8bf749a8ea645cc06411860572dda 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -25,6 +25,7 @@ function ECCCoGenerator(; use_class_loss::Bool=false, nsamples::Int=50, nmin::Int=25, + use_energy_delta::Bool=false, kwargs... ) @@ -47,7 +48,11 @@ function ECCCoGenerator(; # Energy penalty function _energy_penalty(ce::AbstractCounterfactualExplanation) - return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...) + if use_energy_delta + return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin, kwargs...) + else + return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...) 
"""
    energy_delta(ce::AbstractCounterfactualExplanation; kwargs...)

Compute the average difference between the energy of the current counterfactual state of
`ce` and the energy of samples obtained from an energy sampler, where the energy of `x` is
the negative logit of `ce.M` for the target class of `ce`.

The sampler is created on first use with `ECCCo.EnergySampler` and cached under the key
`:energy_sampler` in `ce.params`; later calls reuse the cached sampler. Sample retrieval
runs inside `ignore_derivatives`, so it is excluded from differentiation.

# Keywords
- `n::Int=50`: passed as `nsamples` when the sampler is constructed; also the number of
  samples drawn when `choose_random=true`.
- `niter=500`: passed to `ECCCo.EnergySampler` when the sampler is constructed.
- `from_buffer=true`: forwarded to `rand` when `choose_random=true`.
- `agg=mean`: accepted but not used by this function.
- `choose_lowest_energy=true`: use `ECCCo.get_lowest_energy_sample` to pick samples.
- `choose_random=false`: draw `n` samples with `rand` from the sampler instead.
  If both flags are `false`, the sampler's whole `buffer` is used.
- `nmin::Int=25`: number of lowest-energy samples requested; capped at `n` and at the
  size of the last dimension of the sampler's buffer.
- `return_conditionals=false`: return the samples instead of the loss.
- `kwargs...`: forwarded to `ECCCo.EnergySampler` when the sampler is constructed.

# Returns
The samples if `return_conditionals` is `true`; otherwise the sum of the energy
differences divided by the number of differences.

# Throws
- `ArgumentError` if both `choose_lowest_energy` and `choose_random` are `true`.
"""
function energy_delta(
    ce::AbstractCounterfactualExplanation;
    n::Int=50, niter=500, from_buffer=true, agg=mean,
    choose_lowest_energy=true,
    choose_random=false,
    nmin::Int=25,
    return_conditionals=false,
    kwargs...
)

    # Explicit validation: `@assert` may be disabled and must not guard user input.
    if choose_lowest_energy && choose_random
        msg = "Must choose either lowest energy or random samples or neither."
        throw(ArgumentError(msg))
    end

    nlowest = min(nmin, n)

    # The block's value is returned directly, so no captured variable is reassigned.
    xgenerated = ignore_derivatives() do
        eng_sampler = get!(ce.params, :energy_sampler) do
            ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
        end
        if choose_lowest_energy
            # Never request more samples than the buffer holds.
            k = min(nlowest, size(eng_sampler.buffer)[end])
            ECCCo.get_lowest_energy_sample(eng_sampler; n=k)
        elseif choose_random
            rand(eng_sampler, n; from_buffer=from_buffer)
        else
            eng_sampler.buffer
        end
    end

    # Callers that only want the samples do not pay for the forward passes below.
    if return_conditionals
        return xgenerated
    end

    xproposed = CounterfactualExplanations.decode_state(ce)   # current state
    t = get_target_index(ce.data.y_levels, ce.target)
    E(x) = -logits(ce.M, x)[t, :]                             # negative logits for target class
    deltas = E(xproposed) .- E(xgenerated)

    # Normalise by the number of differences actually computed: this equals `n` only
    # for random draws, not for the `nmin` lowest-energy samples or the full buffer.
    return sum(deltas) / length(deltas)

end