diff --git a/artifacts/results/mnist_classifier.jls b/artifacts/results/mnist_classifier.jls new file mode 100644 index 0000000000000000000000000000000000000000..1a66ab701028818f831939d432d2e82f2efaecae Binary files /dev/null and b/artifacts/results/mnist_classifier.jls differ diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls new file mode 100644 index 0000000000000000000000000000000000000000..cb9f75b93617511d75340cb68ff57b2b3e59cf74 Binary files /dev/null and b/artifacts/results/mnist_vae.jls differ diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls new file mode 100644 index 0000000000000000000000000000000000000000..b8bd092b81cda8bbdb19a3daa4d2974588d7513e Binary files /dev/null and b/artifacts/results/mnist_vae_weak.jls differ diff --git a/notebooks/dev_mnist.qmd b/notebooks/dev_mnist.qmd new file mode 100644 index 0000000000000000000000000000000000000000..9b2b2f6d9a314513725b042c50d1420df9e0c13a --- /dev/null +++ b/notebooks/dev_mnist.qmd @@ -0,0 +1,365 @@ +```{julia} +include("notebooks/setup.jl") +eval(setup_notebooks) +``` + +# MNIST + +```{julia} +function pre_process(x; noise::Float32=0.03f0) + ϵ = Float32.(randn(size(x)) * noise) + x = @.(2 * x - 1) .+ ϵ + return x +end +``` + +```{julia} +# Data: +n_obs = 1000 +counterfactual_data = load_mnist(n_obs) +counterfactual_data.X = pre_process.(counterfactual_data.X) +X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) +X = table(permutedims(X)) +labels = counterfactual_data.output_encoder.labels +input_dim, n_obs = size(counterfactual_data.X) +n_digits = Int(sqrt(input_dim)) +output_dim = length(unique(labels)) +``` + +First, let's create a couple of image classifier architectures: + +```{julia} +# Model parameters: +epochs = 100 +batch_size = minimum([Int(round(n_obs/10)), 128]) +n_hidden = 128 +activation = Flux.relu +builder = MLJFlux.@builder Flux.Chain( + + Dense(n_in, n_hidden, activation), + # 
Dense(n_hidden, n_hidden, activation), + # Dense(n_hidden, n_hidden, activation), + + # Dense(n_in, n_hidden), + # BatchNorm(n_hidden, activation), + # Dense(n_hidden, n_hidden), + # BatchNorm(n_hidden, activation), + + Dense(n_hidden, n_out), +) +# builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation) +# builder = MLJFlux.MLP( +# hidden=( +# n_hidden, +# n_hidden, +# n_hidden, +# ), +# σ=activation +# ) +α = [1.0,1.0,5e-3] + +# Simple MLP: +mlp = NeuralNetworkClassifier( + builder=builder, + epochs=epochs, + batch_size=batch_size, + finaliser=x -> x, + loss=Flux.Losses.logitcrossentropy, +) + +# Deep Ensemble: +mlp_ens = EnsembleModel(model=mlp, n=5) + +# Joint Energy Model: +𝒟x = Uniform(-1,1) +𝒟y = Categorical(ones(output_dim) ./ output_dim) +sampler = ConditionalSampler( + 𝒟x, 𝒟y, + input_size=(input_dim,), + batch_size=1 +) +jem = JointEnergyClassifier( + sampler; + builder=builder, + batch_size=batch_size, + finaliser=x -> x, + loss=Flux.Losses.logitcrossentropy, + jem_training_params=( + α=α,verbosity=10, + # use_gen_loss=false, + # use_reg_loss=false, + ), + sampling_steps=20, + epochs=epochs, +) + +# Deep Ensemble of Joint Energy Models: +jem_ens = EnsembleModel(model=jem, n=5) +``` + +```{julia} +cov = .95 +conf_model = conformal_model(jem_ens; method=:adaptive_inductive, coverage=cov) +mach = machine(conf_model, X, labels) +fit!(mach) +M = ECCCo.ConformalModel(mach.model, mach.fitresult) +``` + +```{julia} +if mach.model.model isa JointEnergyClassifier + sampler = mach.model.model.jem.sampler +else + K = length(counterfactual_data.y_levels) + input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1)) + 𝒟x = Uniform(extrema(counterfactual_data.X)...) 
+ 𝒟y = Categorical(ones(K) ./ K) + sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size) +end +opt = ImproperSGLD() +f(x) = logits(M, x) + +n_iter = 200 +_w = 1500 +plts = [] +neach = 10 +for i in 1:10 + x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i) + plts_i = [] + for j in 1:size(x, 2) + xj = x[:,j] + xj = reshape(xj, (n_digits, n_digits)) + plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)] + end + plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10)) + plts = [plts..., plt] +end +plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1)) +display(plt) + +``` + +```{julia} +test_data = load_mnist_test() +test_data.X = pre_process.(test_data.X) +f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data) +println("F1 score (test): $(round(f1,digits=3))") +``` + +```{julia} +Random.seed!(1234) + +# Set up search: +factual_label = 9 +x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) +target = 7 +factual = predict_label(M, counterfactual_data, x)[1] +γ = 0.9 +T = 100 + +η=1.0 + +# Generate counterfactual using generic generator: +generator = GenericGenerator(opt=Flux.Optimise.Adam(0.01),) +ce_wachter = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, + converge_when=:generator_conditions, +) + +generator = GreedyGenerator(η=η) +ce_jsma = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, + converge_when=:generator_conditions, +) + +# ECCCo: +λ=[0.1,0.1,0.1] +temp=0.1 + +# Generate counterfactual using ECCCo generator: +generator = ECCCoGenerator( + λ=λ, + temp=temp, + opt=Flux.Optimise.Adam(0.01), +) +ce_conformal = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, + 
converge_when=:generator_conditions, +) + +# Generate counterfactual using ECCCo generator: +generator = ECCCoGenerator( + λ=λ, + temp=temp, + opt=CounterfactualExplanations.Generators.JSMADescent(η=η), +) +ce_conformal_jsma = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, + converge_when=:generator_conditions, +) + +# Plot: +p1 = Plots.plot( + convert2image(MNIST, reshape(x,28,28)), + axis=nothing, + size=(img_height, img_height), + title="Factual" +) +plts = [p1] + +ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma] +_names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"] +for x in zip(ces, _names) + ce, _name = (x[1],x[2]) + x = CounterfactualExplanations.counterfactual(ce) + _phat = target_probs(ce) + _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" + plt = Plots.plot( + convert2image(MNIST, reshape(x,28,28)), + axis=nothing, + size=(img_height, img_height), + title=_title + ) + plts = [plts..., plt] +end +plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) +display(plt) +savefig(plt, joinpath(www_path, "eccco_mnist.png")) +``` + +```{julia} +# Random.seed!(1234) + +# Set up search: +factual_label = 8 +x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) +target = 3 +factual = predict_label(M, counterfactual_data, x)[1] +γ = 0.5 +T = 100 + +# Generate counterfactual using generic generator: +generator = GenericGenerator(opt=Flux.Optimise.Adam(),) +ce_wachter = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, +) + +generator = GreedyGenerator(η=1.0) +ce_jsma = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, +) + +# ECCCo: +λ=[0.0,1.0] +temp=0.5 + +# Generate counterfactual using CCE 
generator: +generator = CCEGenerator( + λ=λ, + temp=temp, + opt=Flux.Optimise.Adam(), +) +ce_conformal = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, + converge_when=:generator_conditions, +) + +# Generate counterfactual using CCE generator: +generator = CCEGenerator( + λ=λ, + temp=temp, + opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0), +) +ce_conformal_jsma = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, + converge_when=:generator_conditions, +) + +# Plot: +p1 = Plots.plot( + convert2image(MNIST, reshape(x,28,28)), + axis=nothing, + size=(img_height, img_height), + title="Factual" +) +plts = [p1] + +ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma] +_names = ["Wachter", "CCE", "JSMA", "CCE-JSMA"] +for x in zip(ces, _names) + ce, _name = (x[1],x[2]) + x = CounterfactualExplanations.counterfactual(ce) + _phat = target_probs(ce) + _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" + plt = Plots.plot( + convert2image(MNIST, reshape(x,28,28)), + axis=nothing, + size=(img_height, img_height), + title=_title + ) + plts = [plts..., plt] +end +plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) +display(plt) +savefig(plt, joinpath(www_path, "cce_mnist.png")) +``` + +```{julia} +if M.model.model isa JointEnergyModels.JointEnergyClassifier + jem = M.model.model.jem + n_iter = 200 + _w = 1500 + plts = [] + neach = 10 + for i in 1:10 + x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i) + plts_i = [] + for j in 1:size(x, 2) + xj = x[:,j] + xj = reshape(xj, (n_digits, n_digits)) + plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)] + end + plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10)) + plts = [plts..., plt] + end + plt = Plots.plot(plts..., size=(_w,_w), 
layout=(10,1)) + display(plt) +end +``` + +## Benchmark + +```{julia} +# Benchmark generators: +generators = Dict( + :wachter => GenericGenerator(opt=opt, λ=l2_λ), + :revise => REVISEGenerator(opt=opt, λ=l2_λ), + :greedy => GreedyGenerator(), +) + +# Conformal Models: + + +# Measures: +measures = [ + CounterfactualExplanations.distance, + ECCCo.distance_from_energy, + ECCCo.distance_from_targets, + CounterfactualExplanations.validity, +] +``` \ No newline at end of file diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index b4efbc521ca2758c7437efb9468b2d3e76b4be43..3ca0ba0d1753be4d56de5da65c9567ba92163bb1 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -5,6 +5,129 @@ eval(setup_notebooks) # MNIST +## Anecdotal Evidence + +### Examples in Introduction + +#### Wachter and JSMA + +```{julia} +# Data: +counterfactual_data = load_mnist() +X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) +input_dim, n_obs = size(counterfactual_data.X) +M = load_mnist_mlp() + +# Target: +factual_label = 8 +x = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) +target = 3 +factual = predict_label(M, counterfactual_data, x)[1] +γ = 0.9 +T = 50 +``` + +```{julia} +# Search: +opt = Flux.Optimise.Adam(0.01) +generator = GenericGenerator(opt=opt) +ce_wachter = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, +) +generator = GreedyGenerator(η=1.0) +ce_jsma = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, +) +``` + +```{julia} +p1 = Plots.plot( + convert2image(MNIST, reshape(x,28,28)), + axis=nothing, + size=(img_height, img_height), + title="Factual" +) +plts = [p1] + +ces = zip([ce_wachter,ce_jsma]) +counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces)) +phat = 
reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces)) +for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Wachter","JSMA"]) + ce, _phat, _name = (x[1],x[2],x[3]) + _title = "$(_name) (p=$(round(_phat[1]; digits=2)))" + plt = Plots.plot( + convert2image(MNIST, reshape(ce,28,28)), + axis=nothing, + size=(img_height, img_height), + title=_title + ) + plts = [plts..., plt] +end +plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) +display(plt) +savefig(plt, joinpath(www_path, "you_may_not_like_it.png")) +``` + +#### REVISE + +```{julia} +using CounterfactualExplanations.Models: load_mnist_vae +vae = load_mnist_vae() +vae_weak = load_mnist_vae(;strong=false) +Serialization.serialize(joinpath(output_path,"mnist_classifier.jls"), M) +Serialization.serialize(joinpath(output_path,"mnist_vae.jls"), vae) +Serialization.serialize(joinpath(output_path,"mnist_vae_weak.jls"), vae_weak) +``` + +```{julia} +# Define generator: +generator = REVISEGenerator( + opt = opt, + λ=0.01 +) +# Generate recourse: +counterfactual_data.generative_model = vae # assign generative model +ce_strong = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, +) +counterfactual_data = deepcopy(counterfactual_data) +counterfactual_data.generative_model = vae_weak +ce_weak = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, +) +``` + +```{julia} +ces = zip([ce_strong,ce_weak]) +counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces)) +phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces)) +plts = [p1] +for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Strong VAE","Weak VAE"]) + ce, _phat, _name = (x[1],x[2],x[3]) + _title = "$(_name) (p=$(round(_phat[1]; 
digits=2)))" + plt = Plots.plot( + convert2image(MNIST, reshape(ce,28,28)), + axis=nothing, + size=(img_height, img_height), + title=_title + ) + plts = [plts..., plt] +end +plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) +display(plt) +savefig(plt, joinpath(www_path, "surrogate_gone_wrong.png")) +``` + +### ECCo + ```{julia} function pre_process(x; noise::Float32=0.03f0) ϵ = Float32.(randn(size(x)) * noise) @@ -15,8 +138,6 @@ end ```{julia} # Data: -n_obs = 1000 -counterfactual_data = load_mnist(n_obs) counterfactual_data.X = pre_process.(counterfactual_data.X) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) X = table(permutedims(X)) @@ -35,74 +156,84 @@ batch_size = minimum([Int(round(n_obs/10)), 128]) n_hidden = 128 activation = Flux.relu builder = MLJFlux.@builder Flux.Chain( - Dense(n_in, n_hidden, activation), - # Dense(n_hidden, n_hidden, activation), - # Dense(n_hidden, n_hidden, activation), - - # Dense(n_in, n_hidden), - # BatchNorm(n_hidden, activation), - # Dense(n_hidden, n_hidden), - # BatchNorm(n_hidden, activation), - Dense(n_hidden, n_out), ) -# builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation) -# builder = MLJFlux.MLP( -# hidden=( -# n_hidden, -# n_hidden, -# n_hidden, -# ), -# σ=activation -# ) -α = [1.0,1.0,5e-3] +n_ens = 5 # number of models in ensemble +_loss = Flux.Losses.logitcrossentropy # loss function +_finaliser = x -> x # finaliser function +``` +```{julia} +# JEM parameters: +𝒟x = Uniform(-1,1) +𝒟y = Categorical(ones(output_dim) ./ output_dim) +sampler = ConditionalSampler( + 𝒟x, 𝒟y, + input_size=(input_dim,), + batch_size=1 +) +α = [1.0,1.0,1e-2] # penalty strengths +``` + +```{julia} # Simple MLP: mlp = NeuralNetworkClassifier( builder=builder, epochs=epochs, batch_size=batch_size, - finaliser=x -> x, - loss=Flux.Losses.logitcrossentropy, + finaliser=_finaliser, + loss=_loss, ) # Deep Ensemble: -mlp_ens = 
EnsembleModel(model=mlp, n=5) +mlp_ens = EnsembleModel(model=mlp, n=n_ens) # Joint Energy Model: -𝒟x = Uniform(-1,1) -𝒟y = Categorical(ones(output_dim) ./ output_dim) -sampler = ConditionalSampler( - 𝒟x, 𝒟y, - input_size=(input_dim,), - batch_size=1 -) jem = JointEnergyClassifier( sampler; builder=builder, + epochs=epochs, batch_size=batch_size, - finaliser=x -> x, - loss=Flux.Losses.logitcrossentropy, + finaliser=_finaliser, + loss=_loss, jem_training_params=( α=α,verbosity=10, - # use_gen_loss=false, - # use_reg_loss=false, ), sampling_steps=20, - epochs=epochs, ) +# JEM with adversarial training: +jem_adv = deepcopy(jem) +# jem_adv.adv_training = true + # Deep Ensemble of Joint Energy Models: -jem_ens = EnsembleModel(model=jem, n=5) +jem_ens = EnsembleModel(model=jem, n=n_ens) + +# Deep Ensemble of Joint Energy Models with adversarial training: +# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens) + +# Dictionary of models: +models = Dict( + "MLP" => mlp, + "MLP Ensemble" => mlp_ens, + "JEM" => jem, + "JEM Ensemble" => jem_ens, + # "JEM Ensemble+" => jem_ens_plus, +) ``` + ```{julia} -cov = .95 -conf_model = conformal_model(jem_ens; method=:adaptive_inductive, coverage=cov) -mach = machine(conf_model, X, labels) -fit!(mach) -M = ECCCo.ConformalModel(mach.model, mach.fitresult) +# Train models: +function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive) + conf_model = conformal_model(model; method=method, coverage=cov) # wrap the model passed in, not the global jem_ens + mach = machine(conf_model, X, y) + fit!(mach) + M = ECCCo.ConformalModel(mach.model, mach.fitresult) + return M +end +model_dict = Dict(mod_name => _train(mod) for (mod_name, mod) in models) ``` ```{julia} @@ -149,9 +280,9 @@ println("F1 score (test): $(round(f1,digits=3))") Random.seed!(1234) # Set up search: -factual_label = 4 +factual_label = 9 x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) -target = 5 
+target = 7 factual = predict_label(M, 
counterfactual_data, x)[1] γ = 0.9 T = 100 @@ -159,7 +290,7 @@ T = 100 η=1.0 # Generate counterfactual using generic generator: -generator = GenericGenerator() +generator = GenericGenerator(opt=Flux.Optimise.Adam(0.01),) ce_wachter = generate_counterfactual( x, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, diff --git a/paper/paper.pdf b/paper/paper.pdf index 56e5e6bc610e19203b81094a4d550becf73727cd..d59b8d6eb51bd67c8b0070adfbd3bf6244bdcd2f 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/paper/paper.tex b/paper/paper.tex index 3f271797ba751480aa2f429098c24c45c55b9aa3..f4bd6b8ad15a1921a1426800f8b795fbf0afd73c 100644 --- a/paper/paper.tex +++ b/paper/paper.tex @@ -96,7 +96,7 @@ \begin{abstract} - We propose Conformal Counterfactual Explanations: an effortless and rigorous way to produce plausible and conformal Counterfactual Explanations for Black Box Models using Conformal Prediction. To address the need for plausible explanations, existing work has primarily relied on surrogate models to learn the data-generating process. This effectively reallocates the task of learning realistic representations of the data from the model itself to the surrogate. Consequently, the generated explanations may look plausible to humans but not necessarily conform with the behaviour of the Black Box Model. We formalise this notion through the introduction of new evaluation measures. In order to still address the need for plausibility, we build on a recent approach that works by minimizing predictive model uncertainty. Using differentiable Conformal Prediction, we relax the previous assumption that the Black Box Model can produce predictive uncertainty estimates. + We propose Eccoe: an effortless and rigorous way to produce plausible and conformal Counterfactual Explanations for Black Box Models using Conformal Prediction. 
To address the need for plausible explanations, existing work has primarily relied on surrogate models to learn the data-generating process. This effectively reallocates the task of learning realistic representations of the data from the model itself to the surrogate. Consequently, the generated explanations may look plausible to humans but not necessarily conform with the behaviour of the Black Box Model. We formalise this notion through the introduction of new evaluation measures. In order to still address the need for plausibility, we build on a recent approach that works by minimizing predictive model uncertainty. Using differentiable Conformal Prediction, we relax the previous assumption that the Black Box Model can produce predictive uncertainty estimates. \end{abstract} \section{Introduction}\label{intro} @@ -266,9 +266,9 @@ The fact that conformal classifiers produce set-valued predictions introduces a where $\kappa \in \{0,1\}$ is a hyper-parameter and $C_{\theta,\mathbf{y}}(\mathbf{x}_i;\alpha)$ can be interpreted as the probability of label $\mathbf{y}$ being included in the prediction set. Formally, it is defined as $C_{\theta,\mathbf{y}}(\mathbf{x}_i;\alpha):=\sigma\left((s(\mathbf{x}_i,\mathbf{y})-\alpha) T^{-1}\right)$ for $\mathbf{y}\in\mathcal{Y}$ where $\sigma$ is the sigmoid function and $T$ is a hyper-parameter used for temperature scaling \citep{stutz2022learning}. -Penalizing the set size in this way is in principal enough to train efficient conformal classifiers \citep{stutz2022learning}. As we explained above, the set size is also closely linked to predictive uncertainty at the local level. This makes the smooth penalty defined in Equation~\ref{eq:setsize} useful in the context of meeting our objective of generating plausible counterfactuals. 
In particular, we adapt Equation~\ref{eq:general} to define the baseline objective for Conformal Counterfactual Explanations (ECCCo): +Penalizing the set size in this way is in principle enough to train efficient conformal classifiers \citep{stutz2022learning}. As we explained above, the set size is also closely linked to predictive uncertainty at the local level. This makes the smooth penalty defined in Equation~\ref{eq:setsize} useful in the context of meeting our objective of generating plausible counterfactuals. In particular, we adapt Equation~\ref{eq:general} to define the objective for our proposed Energy-Constrained Conformal Counterfactual Explanations (Eccoe): -\begin{equation}\label{eq:cce} +\begin{equation}\label{eq:eccoe} \begin{aligned} \mathbf{Z}^\prime &= \arg \min_{\mathbf{Z}^\prime \in \mathcal{Z}^M} \left\{ {\text{yloss}(M_{\theta}(f(\mathbf{Z}^\prime)),\mathbf{y}^*)}+ \lambda \Omega(C_{\theta}(f(\mathbf{Z}^\prime);\alpha)) \right\} \end{aligned} diff --git a/www/eccco_mnist.png b/www/eccco_mnist.png index 8b0a8a14d08b2310cfafe1ce68f55891d6ba1a33..c561119c486e826ba460c2dbfa4eeb2acbf52b31 100644 Binary files a/www/eccco_mnist.png and b/www/eccco_mnist.png differ diff --git a/www/surrogate_gone_wrong.png b/www/surrogate_gone_wrong.png index ac97976b41383c1e0722c5bddaaabc739499c013..6808922c305687ab54508ceec9a770db30dae508 100644 Binary files a/www/surrogate_gone_wrong.png and b/www/surrogate_gone_wrong.png differ diff --git a/www/you_may_not_like_it.png b/www/you_may_not_like_it.png index 8e30bcfbe8aeaa402dd8aaa452d2b63f396b80e5..379a50da7f1b073e27a70eac5f385e5b3218e0be 100644 Binary files a/www/you_may_not_like_it.png and b/www/you_may_not_like_it.png differ
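
Note on the `paper/paper.tex` hunk: the soft inclusion term $C_{\theta,\mathbf{y}}(\mathbf{x}_i;\alpha)=\sigma\left((s(\mathbf{x}_i,\mathbf{y})-\alpha) T^{-1}\right)$ can be sketched in a few lines of plain Julia. This is only an illustration, not the ECCCo implementation: the function names, the example scores and threshold, and the hinge form $\max(0, \sum_{\mathbf{y}} C_{\theta,\mathbf{y}} - \kappa)$ assumed for the set-size penalty are all hypothetical, since the set-size equation itself is not part of this diff.

```julia
# Logistic sigmoid, written out so the sketch needs no packages.
σ(z) = 1 / (1 + exp(-z))

# Soft inclusion score per label: σ((s - α) / T), as defined in the paper text.
# `scores` holds s(x, y) for every label y; `α` is the calibrated threshold.
soft_inclusion(scores, α; T=0.5) = σ.((scores .- α) ./ T)

# ASSUMPTION: hinge-style smooth set-size penalty, max(0, Σ_y C_y - κ).
function set_size_penalty(scores, α; κ=1, T=0.5)
    return max(0, sum(soft_inclusion(scores, α; T=T)) - κ)
end

scores = [0.9, 0.2, 0.1]  # hypothetical conformity scores for three labels
α_hat = 0.5               # hypothetical calibrated threshold
println(soft_inclusion(scores, α_hat))
println(set_size_penalty(scores, α_hat))
```

A smaller temperature `T` pushes each soft inclusion score towards 0 or 1, so the sum approaches the hard prediction-set size. Under the assumed hinge form, `κ=1` leaves singleton sets unpenalised and `κ=0` penalises any non-empty set.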