From 335f7133f0f0515686120db93f0eb3bcba1c39f9 Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Thu, 10 Aug 2023 11:52:24 +0200 Subject: [PATCH] :fire: --- notebooks/.CondaPkg/env/conda-meta/history | 2 +- notebooks/.CondaPkg/meta | Bin 753 -> 753 bytes notebooks/prototyping.qmd | 130 +++++++++++++++++++++ src/generator.jl | 7 +- src/penalties.jl | 47 ++++++++ 5 files changed, 184 insertions(+), 2 deletions(-) create mode 100644 notebooks/prototyping.qmd diff --git a/notebooks/.CondaPkg/env/conda-meta/history b/notebooks/.CondaPkg/env/conda-meta/history index dd869616..6c7683b9 100644 --- a/notebooks/.CondaPkg/env/conda-meta/history +++ b/notebooks/.CondaPkg/env/conda-meta/history @@ -1,4 +1,4 @@ -==> 2023-08-08 15:03:23 <== +==> 2023-08-10 11:22:12 <== # cmd: /Users/FA31DU/.julia/artifacts/6ecf04294c7f327e02e84972f34835649a5eb35e/bin/micromamba -r /Users/FA31DU/.julia/scratchspaces/0b3b1443-0f03-428d-bdfb-f27f9c1191ea/root create -y -p /Users/FA31DU/code/ECCCo.jl/notebooks/.CondaPkg/env --override-channels --no-channel-priority numpy[version='*'] pip[version='>=22.0.0'] python[version='>=3.7,<4',channel='conda-forge',build='*cpython*'] -c conda-forge # conda version: 3.8.0 +https://conda.anaconda.org/conda-forge/osx-64::xz-5.2.6-h775f41a_0 diff --git a/notebooks/.CondaPkg/meta b/notebooks/.CondaPkg/meta index 1b11b47e6fbbf80a0ee54e97be26195177f40138..6b7ba5c59ae065f773611edc2fc41659836fb772 100644 GIT binary patch delta 23 acmey!`jM57hXDd!sz)non{MQL%me^IFa?(Y delta 23 acmey!`jM57hXDdI8RtHHX0nm*F%tkrhX$<x diff --git a/notebooks/prototyping.qmd b/notebooks/prototyping.qmd new file mode 100644 index 00000000..7b61b6f8 --- /dev/null +++ b/notebooks/prototyping.qmd @@ -0,0 +1,130 @@ +```{julia} +include("$(pwd())/notebooks/setup.jl") +eval(setup_notebooks) +``` + +# Linearly Separable Data + +```{julia} +# Hyper: +_retrain = false + +# Data: +test_size = 0.2 +n_obs = Int(1000 / (1.0 - test_size)) +counterfactual_data, test_data = 
train_test_split( + load_blobs(n_obs; cluster_std=0.1, center_box=(-1. => 1.)); + test_size=test_size +) +X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) +X = table(permutedims(X)) +labels = counterfactual_data.output_encoder.labels +input_dim, n_obs = size(counterfactual_data.X) +output_dim = length(unique(labels)) +``` + +First, let's set up the classifier architecture: + +```{julia} +# Model parameters: +epochs = 100 +batch_size = minimum([Int(round(n_obs/10)), 128]) +n_hidden = 16 +activation = Flux.swish +builder = MLJFlux.MLP( + hidden=(n_hidden, n_hidden, n_hidden), + σ=Flux.swish +) +n_ens = 5 # number of models in ensemble +_loss = Flux.Losses.crossentropy # loss function +_finaliser = Flux.softmax # finaliser function +``` + +```{julia} +# JEM parameters: +𝒟x = Normal() +𝒟y = Categorical(ones(output_dim) ./ output_dim) +sampler = ConditionalSampler( + 𝒟x, 𝒟y, + input_size=(input_dim,), + batch_size=50, +) +α = [1.0,1.0,1e-1] # penalty strengths +``` + + +```{julia} +# Joint Energy Model: +model = JointEnergyClassifier( + sampler; + builder=builder, + epochs=epochs, + batch_size=batch_size, + finaliser=_finaliser, + loss=_loss, + jem_training_params=( + α=α,verbosity=10, + ), + sampling_steps=30, +) +``` + +```{julia} +conf_model = conformal_model(model; method=:simple_inductive, coverage=0.95) +mach = machine(conf_model, X, labels) +@info "Begin training model." +fit!(mach) +@info "Finished training model." +M = ECCCo.ConformalModel(mach.model, mach.fitresult) +``` + +```{julia} +λ₁ = 0.25 +λ₂ = 0.75 +λ₃ = 0.75 +Λ = [λ₁, λ₂, λ₃] + +opt = Flux.Optimise.Descent(0.01) +use_class_loss = false + +# Benchmark generators: +generator_dict = Dict( + "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss), + "ECCCo (energy delta)" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true), +) +``` + +```{julia} +Random.seed!(2023) + +X = X isa Matrix ? 
X : Float32.(permutedims(matrix(X))) +factual_label = levels(labels)[1] +x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) +target = levels(labels)[2] +factual = predict_label(M, counterfactual_data, x_factual)[1] + +ces = Dict{Any,Any}() +plts = [] +for (name, generator) in generator_dict + ce = generate_counterfactual( + x_factual, target, counterfactual_data, M, generator; + initialization=:identity, + converge_when=:generator_conditions, + ) + plt = Plots.plot( + ce, title=name, alpha=0.2, + cbar=false, + ) + if contains(name, "ECCCo") + _X = distance_from_energy(ce, return_conditionals=true) + Plots.scatter!( + _X[1,:],_X[2,:], color=:purple, shape=:star5, + ms=10, label="x̂|$target", alpha=0.5 + ) + end + push!(plts, plt) + ces[name] = ce +end +plt = Plots.plot(plts..., size=(800,350)) +display(plt) +``` \ No newline at end of file diff --git a/src/generator.jl b/src/generator.jl index c356fb1a..2119b544 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -25,6 +25,7 @@ function ECCCoGenerator(; use_class_loss::Bool=false, nsamples::Int=50, nmin::Int=25, + use_energy_delta::Bool=false, kwargs... ) @@ -47,7 +48,11 @@ function ECCCoGenerator(; # Energy penalty function _energy_penalty(ce::AbstractCounterfactualExplanation) - return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...) + if use_energy_delta + return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin, kwargs...) + else + return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...) 
+        end
     end
 
     _penalties = [Objectives.distance_l1, _set_size_penalty, _energy_penalty]
diff --git a/src/penalties.jl b/src/penalties.jl
index d8d1b7f6..b42a7071 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -38,6 +38,72 @@ function set_size_penalty(
 
 end
 
+"""
+    energy_delta(ce::AbstractCounterfactualExplanation; kwargs...)
+
+Energy-delta penalty: the aggregated difference between the energy of the current
+counterfactual state and the energy of conditional samples from the energy sampler.
+Energy is the negative logit of `ce.M` for the target class `ce.target`.
+
+The sampler is created on first use and cached in `ce.params[:energy_sampler]`.
+
+# Keywords
+- `n`: number of samples the sampler is asked for (also used for random draws).
+- `niter`: forwarded to `ECCCo.EnergySampler` when the sampler is created.
+- `from_buffer`: forwarded to `rand` when `choose_random=true`.
+- `agg`: reduction applied to the per-sample energy differences.
+- `choose_lowest_energy`: compare against the `nmin` lowest-energy samples.
+- `choose_random`: compare against `n` samples drawn with `rand`.
+- `nmin`: number of lowest-energy samples; capped at `n` and at the buffer size.
+- `return_conditionals`: return the conditional samples instead of the penalty.
+- `kwargs...`: forwarded to `ECCCo.EnergySampler`.
+
+If neither `choose_lowest_energy` nor `choose_random` is set, the sampler buffer is used.
+"""
+function energy_delta(
+    ce::AbstractCounterfactualExplanation;
+    n::Int=50, niter=500, from_buffer=true, agg=mean,
+    choose_lowest_energy=true,
+    choose_random=false,
+    nmin::Int=25,
+    return_conditionals=false,
+    kwargs...
+)
+
+    @assert !(choose_lowest_energy && choose_random) "Must choose either lowest energy or random samples or neither."
+
+    # Sampling is wrapped in `ignore_derivatives`, so no gradients flow through it.
+    xtarget = ignore_derivatives() do
+        _dict = ce.params
+        # Build the sampler once and reuse it on later calls.
+        if !haskey(_dict, :energy_sampler)
+            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
+        end
+        eng_sampler = _dict[:energy_sampler]
+        if choose_lowest_energy
+            k = min(nmin, n, size(eng_sampler.buffer)[end])
+            ECCCo.get_lowest_energy_sample(eng_sampler; n=k)
+        elseif choose_random
+            rand(eng_sampler, n; from_buffer=from_buffer)
+        else
+            eng_sampler.buffer
+        end
+    end
+
+    # Skip the forward passes when only the samples are requested.
+    if return_conditionals
+        return xtarget
+    end
+
+    x = CounterfactualExplanations.decode_state(ce)  # current state
+    E(z) = -logits(ce.M, z)[ce.target, :]            # negative logits for target class
+
+    # Reduce over however many samples were actually drawn (`nmin`, `n` or the whole
+    # buffer) via `agg`, instead of assuming there are exactly `n` of them.
+    return agg(E(x) .- E(xtarget))
+
+end
+
 function distance_from_energy(
     ce::AbstractCounterfactualExplanation;
     n::Int=50, niter=500, from_buffer=true, agg=mean,
-- 
GitLab