diff --git a/artifacts/results/mnist_architectures.jls b/artifacts/results/mnist_architectures.jls index feebaaeb1db83e61382546fd8ab2ba5709fb7952..166c2de99758134ab6d45178a62742dc11b65871 Binary files a/artifacts/results/mnist_architectures.jls and b/artifacts/results/mnist_architectures.jls differ diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index 7cbb69e8a07f4f740d8accdcb74d978345430631..6bd916f99fd0f8046def3082ec0ddda04f70474c 100644 Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls index 5c9ce01b84418239fae309662e3695b302bdc6e1..43582e6b78a04ca0144dd911bc4b377335c883cd 100644 Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ diff --git a/notebooks/cal_housing.qmd b/notebooks/cal_housing.qmd deleted file mode 100644 index 4b130aa5bc9ff0e10fb22c8ea4362f61cd426f99..0000000000000000000000000000000000000000 --- a/notebooks/cal_housing.qmd +++ /dev/null @@ -1,209 +0,0 @@ -```{julia} -include("$(pwd())/notebooks/setup.jl") -eval(setup_notebooks) -``` - -# Real-World Data - -```{julia} -# Hyper: -_retrain = true - -# Data: -test_size = 0.2 -n_obs = Int(10000 / (1.0 - test_size)) -counterfactual_data, test_data = train_test_split(load_california_housing(n_obs); test_size=test_size) -X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) -X = table(permutedims(X)) -labels = counterfactual_data.output_encoder.labels -input_dim, n_obs = size(counterfactual_data.X) -output_dim = length(unique(labels)) -``` - -First, let's create a couple of image classifier architectures: - -```{julia} -# Model parameters: -epochs = 100 -batch_size = minimum([Int(round(n_obs/10)), 128]) -n_hidden = 128 -activation = Flux.relu -builder = MLJFlux.@builder Flux.Chain( - Dense(n_in, n_hidden, activation), - Dense(n_hidden, n_hidden, activation), - Dense(n_hidden, n_out), -) -n_ens = 5 # number of models in ensemble -_loss = Flux.Losses.logitcrossentropy # loss function -_finaliser = x -> x # finaliser function -``` - -```{julia} -# JEM parameters: -𝒟x = Normal() -𝒟y = Categorical(ones(output_dim) ./ output_dim) -sampler = ConditionalSampler( - 𝒟x, 𝒟y, - input_size=(input_dim,), - batch_size=10, -) -α = [1.0,1.0,1e-1] # penalty strengths -``` - -```{julia} -# Simple MLP: -mlp = NeuralNetworkClassifier( - builder=builder, - epochs=epochs, - batch_size=batch_size, - finaliser=_finaliser, - loss=_loss, -) - -# Deep Ensemble: -mlp_ens = EnsembleModel(model=mlp, n=n_ens) - -# Joint Energy Model: -jem = JointEnergyClassifier( - sampler; - builder=builder, - epochs=epochs, - batch_size=batch_size, - finaliser=_finaliser, - loss=_loss, - jem_training_params=( - α=α,verbosity=10, - ), - sampling_steps=30, -) - -# JEM with adversarial training: -jem_adv = deepcopy(jem) -# jem_adv.adv_training = true - -# Deep Ensemble of Joint Energy Models: -jem_ens = EnsembleModel(model=jem, n=n_ens) - -# Deep Ensemble of Joint Energy Models with adversarial training: -# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens) - -# Dictionary of models: -models = Dict( - "MLP" => mlp, - "MLP Ensemble" => mlp_ens, - "JEM" => jem, - "JEM Ensemble" => jem_ens, - # "JEM Ensemble+" => jem_ens_plus, -) -``` - - -```{julia} -# Train models: -function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model") - conf_model = conformal_model(model; method=method, coverage=cov) - 
mach = machine(conf_model, X, y) - @info "Begin training $mod_name." - fit!(mach) - @info "Finished training $mod_name." - M = ECCCo.ConformalModel(mach.model, mach.fitresult) - return M -end -if _retrain - model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models) - Serialization.serialize(joinpath(output_path,"cal_housing_models.jls"), model_dict) -else - model_dict = Serialization.deserialize(joinpath(output_path,"cal_housing_models.jls")) -end -``` - -```{julia} -# Evaluate models: - -measure = Dict( - :f1score => multiclass_f1score, - :acc => accuracy, - :precision => multiclass_precision -) -model_performance = DataFrame() -for (mod_name, mod) in model_dict - # Test performance: - _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure))) - _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) - _perf.mod_name .= mod_name - model_performance = vcat(model_performance, _perf) -end -Serialization.serialize(joinpath(output_path,"cal_housing_model_performance.jls"), model_performance) -CSV.write(joinpath(output_path, "cal_housing_model_performance.csv"), model_performance) -model_performance -``` - -## Benchmark - -```{julia} -λ₁ = 0.25 -λ₂ = 0.75 -λ₃ = 0.75 -Λ = [λ₁, λ₂, λ₃] - -# Benchmark generators: -generator_dict = Dict( - "Wachter" => WachterGenerator(), - "REVISE" => REVISEGenerator(), - "Schut" => GreedyGenerator(), - "ECCCo" => ECCCoGenerator(λ=Λ), -) -``` - -```{julia} -# Measures: -measures = [ - CounterfactualExplanations.distance, - ECCCo.distance_from_energy, - ECCCo.distance_from_targets, - CounterfactualExplanations.Evaluation.validity, - CounterfactualExplanations.Evaluation.redundancy, - ECCCo.set_size_penalty -] - -bmks = [] -for target in sort(unique(labels)) - for factual in sort(unique(labels)) - if factual == target - continue - end - bmk = benchmark( - counterfactual_data; - models=model_dict, - generators=generator_dict, - measure=measures, - suppress_training=true, dataname="California Housing", - n_individuals=10, - target=target, factual=factual, - initialization=:identity, - converge_when=:generator_conditions, - ) - push!(bmks, bmk) - end -end -bmk = reduce(vcat, bmks) -CSV.write(joinpath(output_path, "cal_housing_benchmark.csv"), bmk()) -``` - -```{julia} -df = @chain bmk() begin - @mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable)) - @mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable)) - @mutate(variable = ifelse.(variable .== "distance", "Cost", variable)) - @mutate(variable = ifelse.(variable .== "redundancy", "Redundancy", variable)) - @mutate(variable = ifelse.(variable .== "Validity", "Validity", variable)) -end -plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) * - mapping(:generator, :value, row=:variable, col=:model, color=:generator) -plt = draw( - plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120), - facet=(; linkyaxes=:none) -) -display(plt) -save(joinpath(output_images_path, "cal_housing_benchmark.png"), plt, px_per_unit=5) -``` \ No newline at end of file diff --git a/notebooks/circles.qmd b/notebooks/circles.qmd index 672ce5429c4d1dd7637388a5c432732d1353a837..a71667d1b20319e9ffc51b7f5afb8b7e12b99202 100644 --- a/notebooks/circles.qmd +++ b/notebooks/circles.qmd @@ -96,6 +96,24 @@ models = Dict( ) ``` +```{julia} +params = DataFrame( + Dict( + :n_obs => Int.(round(n_obs/10)*10), + :epochs => epochs, + :batch_size => batch_size, + 
:n_hidden => n_hidden, + :n_layers => length(builder.hidden), + :activation => string(activation), + :n_ens => n_ens, + :lambda => string(α[3]), + :jem_sampling_steps => jem.sampling_steps, + :sgld_batch_size => sampler.batch_size, + :dataname => "Circles", + ) +) +CSV.write(joinpath(params_path, "circles.csv"), params) +``` ```{julia} # Train models: diff --git a/notebooks/dev_mnist.qmd b/notebooks/dev_mnist.qmd deleted file mode 100644 index 910e798844a326a66e587e650a49aa2c3fc644cb..0000000000000000000000000000000000000000 --- a/notebooks/dev_mnist.qmd +++ /dev/null @@ -1,365 +0,0 @@ -```{julia} -include("$(pwd())/notebooks/setup.jl") -eval(setup_notebooks) -``` - -# MNIST - -```{julia} -function pre_process(x; noise::Float32=0.03f0) - ϵ = Float32.(randn(size(x)) * noise) - x = @.(2 * x - 1) .+ ϵ - return x -end -``` - -```{julia} -# Data: -n_obs = 1000 -counterfactual_data = load_mnist(n_obs) -counterfactual_data.X = pre_process.(counterfactual_data.X) -X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) -X = table(permutedims(X)) -labels = counterfactual_data.output_encoder.labels -input_dim, n_obs = size(counterfactual_data.X) -n_digits = Int(sqrt(input_dim)) -output_dim = length(unique(labels)) -``` - -First, let's create a couple of image classifier architectures: - -```{julia} -# Model parameters: -epochs = 100 -batch_size = minimum([Int(round(n_obs/10)), 128]) -n_hidden = 128 -activation = Flux.relu -builder = MLJFlux.@builder Flux.Chain( - - Dense(n_in, n_hidden, activation), - # Dense(n_hidden, n_hidden, activation), - # Dense(n_hidden, n_hidden, activation), - - # Dense(n_in, n_hidden), - # BatchNorm(n_hidden, activation), - # Dense(n_hidden, n_hidden), - # BatchNorm(n_hidden, activation), - - Dense(n_hidden, n_out), -) -# builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation) -# builder = MLJFlux.MLP( -# hidden=( -# n_hidden, -# n_hidden, -# n_hidden, -# ), -# σ=activation -# ) -α = [1.0,1.0,5e-3] - -# Simple MLP: -mlp = NeuralNetworkClassifier( - builder=builder, - epochs=epochs, - batch_size=batch_size, - finaliser=x -> x, - loss=Flux.Losses.logitcrossentropy, -) - -# Deep Ensemble: -mlp_ens = EnsembleModel(model=mlp, n=5) - -# Joint Energy Model: -𝒟x = Uniform(-1,1) -𝒟y = Categorical(ones(output_dim) ./ output_dim) -sampler = ConditionalSampler( - 𝒟x, 𝒟y, - input_size=(input_dim,), - batch_size=1 -) -jem = JointEnergyClassifier( - sampler; - builder=builder, - batch_size=batch_size, - finaliser=x -> x, - loss=Flux.Losses.logitcrossentropy, - jem_training_params=( - α=α,verbosity=10, - # use_gen_loss=false, - # use_reg_loss=false, - ), - sampling_steps=20, - epochs=epochs, -) - -# Deep Ensemble of Joint Energy Models: -jem_ens = EnsembleModel(model=jem, n=5) -``` - -```{julia} -cov = .95 -conf_model = conformal_model(jem_ens; method=:adaptive_inductive, coverage=cov) -mach = machine(conf_model, X, labels) -fit!(mach) -M = ECCCo.ConformalModel(mach.model, mach.fitresult) -``` - -```{julia} -if mach.model.model isa JointEnergyClassifier - sampler = mach.model.model.jem.sampler -else - K = length(counterfactual_data.y_levels) - input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1)) - 𝒟x = Uniform(extrema(counterfactual_data.X)...) 
- 𝒟y = Categorical(ones(K) ./ K) - sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size) -end -opt = ImproperSGLD() -f(x) = logits(M, x) - -n_iter = 200 -_w = 1500 -plts = [] -neach = 10 -for i in 1:10 - x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i) - plts_i = [] - for j in 1:size(x, 2) - xj = x[:,j] - xj = reshape(xj, (n_digits, n_digits)) - plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)] - end - plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10)) - plts = [plts..., plt] -end -plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1)) -display(plt) - -``` - -```{julia} -test_data = load_mnist_test() -test_data.X = pre_process.(test_data.X) -f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data) -println("F1 score (test): $(round(f1,digits=3))") -``` - -```{julia} -Random.seed!(1234) - -# Set up search: -factual_label = 9 -x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) -target = 7 -factual = predict_label(M, counterfactual_data, x)[1] -γ = 0.9 -T = 100 - -η=1.0 - -# Generate counterfactual using generic generator: -generator = GenericGenerator(opt=Flux.Optimise.Adam(0.01),) -ce_wachter = generate_counterfactual( - x, target, counterfactual_data, M, generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, -) - -generator = GreedyGenerator(η=η) -ce_jsma = generate_counterfactual( - x, target, counterfactual_data, M, generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, -) - -# ECCCo: -λ=[0.1,0.1,0.1] -temp=0.1 - -# Generate counterfactual using ECCCo generator: -generator = ECCCoGenerator( - λ=λ, - temp=temp, - opt=Flux.Optimise.Adam(0.01), -) -ce_conformal = generate_counterfactual( - x, target, counterfactual_data, M, generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, -) - -# Generate counterfactual using ECCCo generator: -generator = ECCCoGenerator( - λ=λ, - temp=temp, - opt=CounterfactualExplanations.Generators.JSMADescent(η=η), -) -ce_conformal_jsma = generate_counterfactual( - x, target, counterfactual_data, M, generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, -) - -# Plot: -p1 = Plots.plot( - convert2image(MNIST, reshape(x,28,28)), - axis=nothing, - size=(img_height, img_height), - title="Factual" -) -plts = [p1] - -ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma] -_names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"] -for x in zip(ces, _names) - ce, _name = (x[1],x[2]) - x = CounterfactualExplanations.counterfactual(ce) - _phat = target_probs(ce) - _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" - plt = Plots.plot( - convert2image(MNIST, reshape(x,28,28)), - axis=nothing, - size=(img_height, img_height), - title=_title - ) - plts = [plts..., plt] -end -plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) -display(plt) -savefig(plt, joinpath(www_path, "eccco_mnist.png")) -``` - -```{julia} -# Random.seed!(1234) - -# Set up search: -factual_label = 8 -x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) -target = 3 -factual = predict_label(M, counterfactual_data, x)[1] -γ = 0.5 -T = 100 - -# Generate counterfactual using generic generator: -generator = 
GenericGenerator(opt=Flux.Optimise.Adam(),) -ce_wachter = generate_counterfactual( - x, target, counterfactual_data, M, generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, -) - -generator = GreedyGenerator(η=1.0) -ce_jsma = generate_counterfactual( - x, target, counterfactual_data, M, generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, -) - -# ECCCo: -λ=[0.0,1.0] -temp=0.5 - -# Generate counterfactual using CCE generator: -generator = CCEGenerator( - λ=λ, - temp=temp, - opt=Flux.Optimise.Adam(), -) -ce_conformal = generate_counterfactual( - x, target, counterfactual_data, M, generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, -) - -# Generate counterfactual using CCE generator: -generator = CCEGenerator( - λ=λ, - temp=temp, - opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0), -) -ce_conformal_jsma = generate_counterfactual( - x, target, counterfactual_data, M, generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, -) - -# Plot: -p1 = Plots.plot( - convert2image(MNIST, reshape(x,28,28)), - axis=nothing, - size=(img_height, img_height), - title="Factual" -) -plts = [p1] - -ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma] -_names = ["Wachter", "CCE", "JSMA", "CCE-JSMA"] -for x in zip(ces, _names) - ce, _name = (x[1],x[2]) - x = CounterfactualExplanations.counterfactual(ce) - _phat = target_probs(ce) - _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" - plt = Plots.plot( - convert2image(MNIST, reshape(x,28,28)), - axis=nothing, - size=(img_height, img_height), - title=_title - ) - plts = [plts..., plt] -end -plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) -display(plt) -savefig(plt, joinpath(www_path, "cce_mnist.png")) -``` - -```{julia} -if M.model.model isa JointEnergyModels.JointEnergyClassifier - jem = M.model.model.jem - n_iter = 200 - _w = 1500 - plts = [] - neach = 10 - for i in 1:10 - x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i) - plts_i = [] - for j in 1:size(x, 2) - xj = x[:,j] - xj = reshape(xj, (n_digits, n_digits)) - plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)] - end - plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10)) - plts = [plts..., plt] - end - plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1)) - display(plt) -end -``` - -## Benchmark - -```{julia} -# Benchmark generators: -generators = Dict( - :wachter => GenericGenerator(opt=opt, λ=l2_λ), - :revise => REVISEGenerator(opt=opt, λ=l2_λ), - :greedy => GreedyGenerator(), -) - -# Conformal Models: - - -# Measures: -measures = [ - CounterfactualExplanations.distance, - ECCCo.distance_from_energy, - ECCCo.distance_from_targets, - CounterfactualExplanations.validity, -] -``` \ No newline at end of file diff --git a/notebooks/fashion_mnist.qmd b/notebooks/fashion_mnist.qmd deleted file mode 100644 index 0f318b06350ba13a68e840b501cadfb78af05eb2..0000000000000000000000000000000000000000 --- a/notebooks/fashion_mnist.qmd +++ /dev/null @@ -1,533 +0,0 @@ -```{julia} -include("$(pwd())/notebooks/setup.jl") -eval(setup_notebooks) -``` - -# FashionMNIST - -## Anecdotal Evidence - -### Examples in Introduction - -#### Wachter and JSMA - -```{julia} -Random.seed!(2023) - -# Data: -counterfactual_data = load_fashion_mnist() -X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) -input_dim, n_obs = 
size(counterfactual_data.X) -M = load_fashion_mnist_mlp() - -# Target: -factual_label = 9 -x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) -target = 7 -factual = predict_label(M, counterfactual_data, x_factual)[1] -γ = 0.9 - -# Training params: -T = 100 -``` - -```{julia} -# Search: -generic_generator = WachterGenerator() -ce_wachter = generate_counterfactual( - x_factual, target, counterfactual_data, M, generic_generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, -) -greedy_generator = GreedyGenerator(η=2.0) -ce_jsma = generate_counterfactual( - x_factual, target, counterfactual_data, M, greedy_generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, -) -``` - -```{julia} -p1 = Plots.plot( - convert2image(FashionMNIST, reshape(x_factual,28,28)), - axis=nothing, - size=(img_height, img_height), - title="Factual" -) -plts = [p1] - -ces = zip([ce_wachter,ce_jsma]) -counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces)) -phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces)) -for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Wachter","JSMA"]) - ce, _phat, _name = (x[1],x[2],x[3]) - _title = "$(_name) (p=$(round(_phat[1]; digits=2)))" - plt = Plots.plot( - convert2image(FashionMNIST, reshape(ce,28,28)), - axis=nothing, - size=(img_height, img_height), - title=_title - ) - plts = [plts..., plt] -end -plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) -display(plt) -savefig(plt, joinpath(output_images_path, "fashion_you_may_not_like_it.png")) -``` - -#### REVISE - -```{julia} -using CounterfactualExplanations.Models: load_fashion_mnist_vae -vae = load_fashion_mnist_vae() -vae_weak = load_fashion_mnist_vae(;strong=false) -Serialization.serialize(joinpath(output_path,"fashion_mnist_classifier.jls"), M) -Serialization.serialize(joinpath(output_path,"fashion_mnist_vae.jls"), vae) -Serialization.serialize(joinpath(output_path,"fashion_mnist_vae_weak.jls"), vae_weak) -``` - -```{julia} -# Define generator: -revise_generator = REVISEGenerator( - opt = Flux.Optimise.Descent(0.25), - λ=0.0, -) -# Generate recourse: -counterfactual_data.generative_model = vae # assign generative model -ce_strong = generate_counterfactual( - x_factual, target, counterfactual_data, M, revise_generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, -) -counterfactual_data_weak = deepcopy(counterfactual_data) -counterfactual_data_weak.generative_model = vae_weak -ce_weak = generate_counterfactual( - x_factual, target, counterfactual_data_weak, M, revise_generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, -) -``` - -```{julia} -ces = zip([ce_strong,ce_weak]) -counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces)) -phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces)) -plts = [p1] -for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Strong VAE","Weak VAE"]) - ce, _phat, _name = (x[1],x[2],x[3]) - _title = "$(_name) (p=$(round(_phat[1]; digits=2)))" - plt = Plots.plot( - convert2image(FashionMNIST, reshape(ce,28,28)), - axis=nothing, - size=(img_height, img_height), - title=_title - ) - plts = [plts..., plt] -end -plt = Plots.plot(plts...; 
size=(img_height*length(plts),img_height), layout=(1,length(plts))) -display(plt) -savefig(plt, joinpath(output_images_path, "fashion_surrogate_gone_wrong.png")) -``` - -### ECCCo - -```{julia} -function pre_process(x; noise::Float32=0.03f0) - ϵ = Float32.(randn(size(x)) * noise) - x += ϵ - return x -end -``` - -```{julia} -# Hyper: -_retrain = true -_regen = true - -# Data: -n_obs = 1000 -counterfactual_data = load_fashion_mnist(n_obs) -counterfactual_data.X = pre_process.(counterfactual_data.X) -counterfactual_data.generative_model = vae -X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) -X = table(permutedims(X)) -x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1) -labels = counterfactual_data.output_encoder.labels -input_dim, n_obs = size(counterfactual_data.X) -n_digits = Int(sqrt(input_dim)) -output_dim = length(unique(labels)) -``` - -First, let's create a couple of image classifier architectures: - -```{julia} -# Model parameters: -epochs = 150 -batch_size = minimum([Int(round(n_obs/10)), 128]) -n_hidden = 512 -activation = Flux.swish -builder = MLJFlux.@builder Flux.Chain( - Dense(n_in, n_hidden, activation), - Dense(n_hidden, n_hidden, activation), - Dense(n_hidden, n_hidden, activation), - Dense(n_hidden, n_out), -) -n_ens = 5 # number of models in ensemble -_loss = Flux.Losses.logitcrossentropy # loss function -_finaliser = x -> x # finaliser function -``` - -```{julia} -# JEM parameters: -𝒟x = Uniform(-1.0,1.0) -𝒟y = Categorical(ones(output_dim) ./ output_dim) -sampler = ConditionalSampler( - 𝒟x, 𝒟y, - input_size=(input_dim,), - batch_size=10, -) -α = [1.0,1.0,5e-2] # penalty strengths -``` - -```{julia} -# Simple MLP: -mlp = NeuralNetworkClassifier( - builder=builder, - epochs=epochs, - batch_size=batch_size, - finaliser=_finaliser, - loss=_loss, -) - -# Deep Ensemble: -mlp_ens = EnsembleModel(model=mlp, n=n_ens) - -# Joint Energy Model: -jem = JointEnergyClassifier( - sampler; - builder=builder, - epochs=epochs, - batch_size=batch_size, - finaliser=_finaliser, - loss=_loss, - jem_training_params=( - α=α,verbosity=10, - ), - sampling_steps=35, -) - -# JEM with adversarial training: -jem_adv = deepcopy(jem) -# jem_adv.adv_training = true - -# Deep Ensemble of Joint Energy Models: -jem_ens = EnsembleModel(model=jem, n=n_ens) - -# Deep Ensemble of Joint Energy Models with adversarial training: -# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens) - -# Dictionary of models: -models = Dict( - "MLP" => mlp, - "MLP Ensemble" => mlp_ens, - "JEM" => jem, - "JEM Ensemble" => jem_ens, - # "JEM Ensemble+" => jem_ens_plus, -) - -Serialization.serialize(joinpath(output_path,"fashion_mnist_architectures.jls"), models) -``` - - -```{julia} -# Train models: -function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model") - conf_model = conformal_model(model; method=method, coverage=cov) - mach = machine(conf_model, X, y) - @info "Begin training $mod_name." - fit!(mach) - @info "Finished training $mod_name." 
- M = ECCCo.ConformalModel(mach.model, mach.fitresult) - return M -end -if _retrain - model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models) - Serialization.serialize(joinpath(output_path,"fashion_mnist_models.jls"), model_dict) -else - model_dict = Serialization.deserialize(joinpath(output_path,"fashion_mnist_models.jls")) -end -``` - -```{julia} -# Plot generated samples: -n_regen = 500 -if _regen - for (mod_name, mod) in model_dict - K = length(counterfactual_data.y_levels) - input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1)) - 𝒟x = Uniform(extrema(counterfactual_data.X)...) - 𝒟y = Categorical(ones(K) ./ K) - sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size, prob_buffer=0.0) - opt = ImproperSGLD() - f(x) = logits(mod, x) - - _w = 1500 - plts = [] - neach = 10 - for i in 1:10 - x = sampler(f, opt; niter=n_regen, n_samples=neach, y=i) - plts_i = [] - for j in 1:size(x, 2) - xj = x[:,j] - xj = reshape(xj, (n_digits, n_digits)) - plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)] - end - plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10)) - plts = [plts..., plt] - end - plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name) - savefig(plt, joinpath(output_images_path, "fashion_mnist_generated_$(mod_name).png")) - display(plt) - end -end -``` - -```{julia} -# Evaluate models: - -measure = Dict( - :f1score => multiclass_f1score, - :acc => accuracy, - :precision => multiclass_precision -) -model_performance = DataFrame() -for (mod_name, mod) in model_dict - # Test performance: - test_data = load_fashion_mnist_test() - _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure))) - _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) - _perf.mod_name .= mod_name - model_performance = vcat(model_performance, _perf) -end -Serialization.serialize(joinpath(output_path,"fashion_mnist_model_performance.jls"), model_performance) -CSV.write(joinpath(output_path, "fashion_mnist_model_performance.csv"), model_performance) -model_performance -``` - -### Different Models - -```{julia} -function _plot_eccco_fashion_mnist( - x::Union{AbstractArray, Int}=x_factual, target::Int=target; - λ=[0.1,0.3,0.3], - temp=0.1, - plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"], - opt = nothing, - rng::Union{Int,AbstractRNG}=1234, - T::Int = 100, -) - - # Setup: - Random.seed!(rng) - if x isa Int - x = reshape(counterfactual_data.X[:,rand(findall(labels.==x))],input_dim,1) - end - - # Generate counterfactuals using ECCCo generator: - eccco_generator = ECCCoGenerator( - λ=λ, - temp=temp, - opt=opt, - ) - - ces = Dict() - for (mod_name, mod) in model_dict - ce = generate_counterfactual( - x, target, counterfactual_data, mod, eccco_generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, - ) - ces[mod_name] = ce - end - _plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order) - - # Plot: - p1 = Plots.plot( - convert2image(FashionMNIST, reshape(x,28,28)), - axis=nothing, - size=(img_height, img_height), - title="Factual" - ) - - plts = [] - for (_name,ce) in ces - _x = CounterfactualExplanations.counterfactual(ce) - _phat = target_probs(ce) - _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" - plt = Plots.plot( - convert2image(FashionMNIST, reshape(_x,28,28)), - axis=nothing, - size=(img_height, img_height), - title=_title - ) - plts = 
[plts..., plt] - end - plts = plts[_plt_order] - plts = [p1, plts...] - plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) - - return plt, eccco_generator -end -``` - -```{julia} -plt, eccco_generator = _plot_eccco_fashion_mnist() -display(plt) -savefig(plt, joinpath(output_images_path, "fashion_mnist_eccco.png")) -``` - -### All digits - -```{julia} -function plot_fashion_mnist( - factual::Int,target::Int; - generator::AbstractGenerator, - model::AbstractFittedModel=model_dict["JEM Ensemble"], - data::CounterfactualData=counterfactual_data, - rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG, - _plot_title::Bool=true, - show_factual::Bool=false, - kwargs..., -) - Random.seed!(rng) - - decision_threshold = !isdefined(kwargs, :decision_threshold) ? 0.9 : decision_threshold - max_iter = !isdefined(kwargs, :max_iter) ? 100 : max_iter - initialization = !isdefined(kwargs, :initialization) ? :identity : initialization - converge_when = !isdefined(kwargs, :converge_when) ? :generator_conditions : converge_when - - x = reshape(data.X[:,rand(findall(predict_label(model, data).==factual))],input_dim,1) - ce = generate_counterfactual( - x, target, data, model, generator; - decision_threshold=decision_threshold, max_iter=max_iter, - initialization=initialization, - converge_when=converge_when, - kwargs... - ) - - _title = _plot_title ? "$(factual) -> $(target)" : "" - - _x = CounterfactualExplanations.counterfactual(ce) - plt = Plots.plot( - convert2image(FashionMNIST, reshape(_x,28,28)), - axis=nothing, - size=(img_height, img_height), - title=_title - ) - if show_factual - plt_factual = Plots.plot( - convert2image(FashionMNIST, reshape(x,28,28)), - axis=nothing, - size=(img_height, img_height), - title="Factual" - ) - plt = Plots.plot(plt_factual, plt; size=(img_height*2,img_height), layout=(1,2)) - end - - return plt -end -``` - -```{julia} -if _regen - function plot_all_digits(rng=123;verbose=true,kwargs...) 
- plts = [] - for i in 0:9 - for j in 0:9 - @info "Generating counterfactual for $(i) -> $(j)" - plt = plot_fashion_mnist(i,j;kwargs...,rng=rng) - !verbose || display(plt) - plts = [plts..., plt] - end - end - plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10)) - return plt - end - plt = plot_all_digits(generator=eccco_generator) - savefig(plt, joinpath(output_images_path, "fashion_mnist_eccco_all_digits.png")) -end -``` - -## Benchmark - -```{julia} -λ₁ = 0.1 -λ₂ = 0.3 -λ₃ = 0.3 -Λ = [λ₁, λ₂, λ₃] - -# Benchmark generators: -generator_dict = Dict( - "Wachter" => WachterGenerator(λ=λ₁), - "REVISE" => REVISEGenerator(λ=λ₁), - "Schut" => GreedyGenerator(), - "ECCCo" => ECCCoGenerator(λ=Λ), -) -``` - -```{julia} -# Measures: -measures = [ - CounterfactualExplanations.distance, - ECCCo.distance_from_energy, - ECCCo.distance_from_targets, - CounterfactualExplanations.Evaluation.validity, - CounterfactualExplanations.Evaluation.redundancy, - ECCCo.set_size_penalty -] - -bmks = [] -for target in sort(unique(labels)) - for factual in sort(unique(labels)) - if factual == target - continue - end - bmk = benchmark( - counterfactual_data; - models=model_dict, - generators=generator_dict, - measure=measures, - suppress_training=true, dataname="FashionMNIST", - n_individuals=5, - target=target, factual=factual, - initialization=:identity, - converge_when=:generator_conditions, - ) - push!(bmks, bmk) - end -end -bmk = reduce(vcat, bmks) - -CSV.write(joinpath(output_path, "fashion_mnist_benchmark.csv"), bmk()) -``` - -```{julia} -df = @chain bmk() begin - @mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable)) - @mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable)) - @mutate(variable = ifelse.(variable .== "distance", "Cost", variable)) - @mutate(variable = ifelse.(variable .== "redundancy", "Redundancy", variable)) - @mutate(variable = ifelse.(variable .== "Validity", "Validity", variable)) -end -plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) * - mapping(:generator, :value, row=:variable, col=:model, color=:generator) -plt = draw( - plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120), - facet=(; linkyaxes=:minimal) -) -display(plt) -save(joinpath(output_images_path, "fashion_mnist_benchmark.png"), plt, px_per_unit=5) -``` \ No newline at end of file diff --git a/notebooks/gmsc.qmd b/notebooks/gmsc.qmd index 2e1066fffa761ece91d52f781f768509289d033f..17a7646a465f2f48748d487ddf577080f3411afe 100644 --- a/notebooks/gmsc.qmd +++ b/notebooks/gmsc.qmd @@ -97,7 +97,6 @@ models = Dict( ) ``` - ```{julia} # Train models: function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model") diff --git a/notebooks/linearly_separable.qmd b/notebooks/linearly_separable.qmd index 3f4b94ec68e62dd4ff099bd491eab704a6aeabf2..97c2fab50464383afa96df34d6b69044c305474f 100644 --- a/notebooks/linearly_separable.qmd +++ b/notebooks/linearly_separable.qmd @@ -99,6 +99,25 @@ models = Dict( ) ``` +```{julia} +params = DataFrame( + Dict( + :n_obs => Int.(round(n_obs/10)*10), + :epochs => epochs, + :batch_size => batch_size, + :n_hidden => n_hidden, + :n_layers => length(builder.hidden), + :activation => string(activation), + :n_ens => n_ens, + :lambda => string(α[3]), + :jem_sampling_steps => jem.sampling_steps, + :sgld_batch_size => sampler.batch_size, + :dataname => "Linearly Separable", + ) +) +CSV.write(joinpath(params_path, "linearly_separable.csv"), params) 
+``` + ```{julia} # Train models: function _train(model, X=X, y=labels; cov=0.95, method=:simple_inductive, mod_name="model") diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 0b8a14b9580a0685d00a340e43e868d4ba8f9ad3..e721f849cd48e85c9244a3fbcbd58f83c4a4bd6b 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -164,11 +164,11 @@ end ```{julia} # Hyper: -_retrain = true -_regen = true +_retrain = false +_regen = false # Data: -n_obs = nothing +n_obs = 10000 counterfactual_data = load_mnist(n_obs) counterfactual_data.X = pre_process.(counterfactual_data.X) counterfactual_data.generative_model = vae @@ -259,7 +259,6 @@ models = Dict( Serialization.serialize(joinpath(output_path,"mnist_architectures.jls"), models) ``` - ```{julia} # Train models: function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model") @@ -279,6 +278,25 @@ else end ``` +```{julia} +params = DataFrame( + Dict( + :n_obs => Int.(round(n_obs/10)*10), + :epochs => epochs, + :batch_size => batch_size, + :n_hidden => n_hidden, + :n_layers => length(model_dict["MLP"].fitresult[1][1])-1, + :activation => string(activation), + :n_ens => n_ens, + :lambda => string(α[3]), + :jem_sampling_steps => jem.sampling_steps, + :sgld_batch_size => sampler.batch_size, + :dataname => "MNIST", + ) +) +CSV.write(joinpath(params_path, "mnist.csv"), params) +``` + ```{julia} # Plot generated samples: n_regen = 500 diff --git a/notebooks/moons.qmd b/notebooks/moons.qmd index 9b4f653685e3f0614bfb5068070c94f5f28ae4c3..97bf10fa74d32ae5c6caa76b5fdd089cfa288f3a 100644 --- a/notebooks/moons.qmd +++ b/notebooks/moons.qmd @@ -96,6 +96,24 @@ models = Dict( ) ``` +```{julia} +params = DataFrame( + Dict( + :n_obs => Int.(round(n_obs/10)*10), + :epochs => epochs, + :batch_size => batch_size, + :n_hidden => n_hidden, + :n_layers => length(builder.hidden), + :activation => string(activation), + :n_ens => n_ens, + :lambda => string(α[3]), + :jem_sampling_steps => jem.sampling_steps, + :sgld_batch_size => sampler.batch_size, + :dataname => "Moons", + ) +) +CSV.write(joinpath(params_path, "moons.csv"), params) +``` ```{julia} # Train models: diff --git a/notebooks/poc.qmd b/notebooks/poc.qmd index 3e3358db70c677e7077cf3fb3daef31f73432cdd..55cf685e9f69fbec6834885005149fea4567b718 100644 --- a/notebooks/poc.qmd +++ b/notebooks/poc.qmd @@ -98,10 +98,12 @@ end #| label: fig-losses #| fig-cap: "Illustration of the smooth size loss and the configurable classification loss." 
-temp = 0.25 +temp = 0.3 +p0 = Plots.contourf(mach.model, mach.fitresult, permutedims(X), labels; plot_set_size=true, zoom=0, temp=temp) p1 = Plots.contourf(mach.model, mach.fitresult, permutedims(X), labels; plot_set_loss=true, zoom=0, temp=temp) p2 = Plots.contourf(mach.model, mach.fitresult, permutedims(X), labels; plot_classification_loss=true, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2,2)) -display(Plots.plot(p1, p2, size=(800,320))) +plt = Plots.plot(p0, p1, p2, size=(1400,320), layout=(1,3)); display(plt) +savefig(plt, joinpath(output_images_path, "poc_set_size.png")) ``` ```{julia} diff --git a/notebooks/setup.jl b/notebooks/setup.jl index 3f091c84a0daaf7d24f1d5a2bb5ee854cfa63644..d29f89b1b6d3658456eaef55b2cdb9565c4e709b 100644 --- a/notebooks/setup.jl +++ b/notebooks/setup.jl @@ -41,6 +41,7 @@ setup_notebooks = quote Plots.theme(:wong) Random.seed!(2023) www_path = "www" + params_path = "artifacts/params" output_path = "artifacts/results" output_images_path = "artifacts/results/images" img_height = 300 diff --git a/notebooks/synthetic.qmd b/notebooks/synthetic.qmd deleted file mode 100644 index 337e10e5c42e5d309f05e21ec85b0c4f7d40bdc6..0000000000000000000000000000000000000000 --- a/notebooks/synthetic.qmd +++ /dev/null @@ -1,348 +0,0 @@ -```{julia} -include("$(pwd())/notebooks/setup.jl") -eval(setup_notebooks); -``` - -# Synthetic data - -```{julia} -#| output: false - -# Data: -datasets = Dict( - :linearly_separable => load_linearly_separable(), - :overlapping => load_overlapping(), - :moons => load_moons(), - :circles => load_circles(), - :multi_class => load_multi_class(), -) - -# Hyperparameters: -cvgs = [0.5, 0.75, 0.95] -temps = [0.01, 0.1, 1.0] -Λ = [0.0, 0.1, 1.0, 10.0] -l2_λ = 0.1 - -# Classifiers: -epochs = 250 -link_fun = relu -logreg = NeuralNetworkClassifier(builder=MLJFlux.Linear(σ=link_fun), epochs=epochs) -mlp = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=link_fun), epochs=epochs) -ensmbl = EnsembleModel(model=mlp, n=5) -classifiers = Dict( - # :logreg => logreg, - :mlp => mlp, - # :ensmbl => ensmbl, -) - -# Search parameters: -target = 2 -factual = 1 -max_iter = 50 -gradient_tol = 1e-2 -opt = Descent(0.01) -``` - -```{julia} -#| echo: false - -results = DataFrame() -for (dataname, data) in datasets - - # Data: - X = table(permutedims(data.X)) - y = data.output_encoder.labels - x = select_factual(data,rand(1:size(data.X,2))) - - for (clf_name, clf) in classifiers, cov in cvgs - - # Classifier and coverage: - conf_model = conformal_model(clf; method=:simple_inductive, coverage=cov) - mach = machine(conf_model, X, y) - fit!(mach) - M = ECCCo.ConformalModel(mach.model, mach.fitresult) - - # Set up ECCCo: - factual_label = predict_label(M, data, x)[1] - target_label = data.y_levels[data.y_levels .!= factual_label][1] - - for λ in Λ, temp in temps - - # ECCCo for given classifier, coverage, temperature and λ: - generator = ECCCoGenerator(temp=temp, λ=[l2_λ,λ], opt=opt) - @assert predict_label(M, data, x) != target_label - ce = try - generate_counterfactual( - x, target_label, data, M, generator; - initialization=:identity, - converge_when=:generator_conditions, - gradient_tol=gradient_tol, - max_iter=max_iter, - ) - catch - missing - end - - _results = DataFrame( - dataset = dataname, - classifier = clf_name, - coverage = cov, - temperature = temp, - λ = λ, - ce = ce, - factual = factual_label, - target = target_label, - ) - append!(results, _results) - - end - - end - -end -``` - -```{julia} -#| echo: false - -function plot_ce(results; 
dataset=:multi_class, classifier=:mlp, λ=0.1, img_height=300, zoom=-0.2) - df_plot = results[results.dataset .== dataset,:] |> - res -> res[res.classifier .== classifier,:] |> - res -> res[res.λ .== λ,:] - plts = map(eachrow(df_plot)) do row - Plots.plot( - row.ce, - title="cov: $(row.coverage), temp: $(row.temperature)", - cbar=false, - zoom=zoom, - legend=false, - ) - end - nrow = length(cvgs) - ncol = length(temps) - _layout = (nrow, ncol) - Plots.plot( - plts..., - size=img_height.*reverse(_layout), layout=_layout, - plot_title="λ: $λ, dataset: $dataset, classifier: $classifier", - ) -end -``` - -```{julia} -#| output: true -#| echo: false - -for dataset in keys(datasets) - Markdown.parse("""### $dataset""") - for classifier in keys(classifiers) - Markdown.parse("""#### $classifier""") - Markdown.parse("""::: {.panel-tabset}""") - for λ in Λ - Markdown.parse("""##### λ: $λ""") - display(plot_ce(results; dataset=dataset, classifier=classifier, λ=λ)) - end - Markdown.parse(""":::""") - end -end -``` - -## Benchmark - -```{julia} -# Benchmark generators: -generators = Dict( - :wachter => GenericGenerator(opt=opt, λ=l2_λ), - :revise => REVISEGenerator(opt=opt, λ=l2_λ), - :greedy => GreedyGenerator(), -) - -# Untrained Models: -models = Dict(Symbol("cov$(Int(100*cov))") => ECCCo.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs) - -# Measures: -measures = [ - CounterfactualExplanations.distance, - ECCCo.distance_from_energy, - ECCCo.distance_from_targets, - CounterfactualExplanations.validity, -] -``` - -### Single CE - -```{julia} -#| echo: false - -_temp = 0.01 -results = DataFrame() -for (dataname, data) in datasets - - # Data: - X = table(permutedims(data.X)) - y = data.output_encoder.labels - x = select_factual(data,rand(1:size(data.X,2))) - - for (modelname, M) in deepcopy(models) - - # Model training: - M = train(M, data) - # Set up ECCCo: - factual_label = predict_label(M, data, x)[1] - target_label = data.y_levels[data.y_levels .!= factual_label][1] - - for λ in Λ - - # Generators: - _generators = deepcopy(generators) - _generators[:cce] = ECCCoGenerator(temp=_temp, λ=[l2_λ,λ], opt=opt) - _generators[:energy] = ECCCo.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt) - _generators[:target] = ECCCo.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt) - - for (gen_name, gen) in _generators - - # ECCCo for given models, λ and generator: - @assert predict_label(M, data, x) != target_label - ce = try - generate_counterfactual( - x, target_label, data, M, gen; - initialization=:identity, - converge_when=:generator_conditions, - gradient_tol=gradient_tol, - max_iter=max_iter, - ) - catch - missing - end - - if !ismissing(ce) - eval = DataFrame(evaluate(ce, measure=measures, output_format=:Dict)) - else - eval = DataFrame(Dict(Symbol(fun) => missing for fun in measures)) - end - - _results = DataFrame( - dataset = dataname, - model = modelname, - λ = λ, - generator = gen_name, - ce = ce, - factual = factual_label, - target = target_label, - ) - - _results = crossjoin(_results, eval; makeunique=true) - append!(results, _results) - - end - end - end -end -``` - -```{julia} -#| echo: false - -function plot_benchmark(results; dataset=:multi_class, modelname=:cov95, img_height=300, zoom=-0.2) - df_plot = results[results.dataset .== dataset,:] |> - res -> res[res.model .== modelname,:] - plts = map(eachrow(df_plot)) do row - Plots.plot( - row.ce, - title="λ: $(row.λ), gen: $(row.generator)", - cbar=false, - zoom=zoom, - legend=false, - ) - end - ncol = 
length(unique(df_plot.generator)) - nrow = length(unique(df_plot.λ)) - _layout = (nrow, ncol) - Plots.plot( - plts..., - size=img_height.*reverse(_layout), layout=_layout, - plot_title="dataset: $dataset, model: $modelname", - ) -end -``` - -```{julia} -#| output: true -#| echo: false - -df = @pivot_longer(results, distance:distance_from_targets) -for dataname ∈ sort(unique(df.dataset)) - Markdown.parse("""### $dataname""") - df_ = df[df.dataset .== dataname, :] - for model in unique(df_.model) - Markdown.parse("""#### model: $model""") - df_plot = df_[df_.model .== model, :] - df_plot = @mutate(df_plot, lambda = string("λ: ", round(λ, digits=2))) - plt = AlgebraOfGraphics.data(df_plot) * visual(BarPlot) * - mapping(:generator, :value, row=:variable, col=:lambda, color=:generator) - plt = draw( - plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=200, height=180), - facet=(; linkyaxes=:minimal) - ) - # plt.figure[0, :] = Label( - # plt.figure, "data: $dataname, model: $model", - # fontsize=20, tellwidth=false - # ) - display(plt) - end -end -``` - -### Full Benchmark - -```{julia} -bmks = [] -for (dataname, dataset) in datasets - for λ in Λ, temp in temps - _generators = deepcopy(generators) - _generators[:cce] = ECCCoGenerator(temp=temp, λ=[l2_λ,λ], opt=opt) - _generators[:energy] = ECCCo.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt) - _generators[:target] = ECCCo.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt) - bmk = benchmark( - dataset; - models=deepcopy(models), - generators=_generators, - measure=measures, - suppress_training=false, dataname=dataname, - n_individuals=5, - initialization=:identity, - ) - bmk.evaluation.λ .= λ - bmk.evaluation.temperature .= temp - push!(bmks, bmk) - end -end -bmk = reduce(vcat, bmks) -``` - -```{julia} -CSV.write(joinpath(output_path, "synthetic_benchmark.csv"), bmk()) -``` - -```{julia} -#| output: true -#| echo: false - -df = bmk() -for dataname ∈ sort(unique(df.dataname)) - Markdown.parse("""### $dataname""") - df_ = df[df.dataname .== dataname, :] - for λ in Λ, temp in temps - Markdown.parse("""#### λ: $λ""") - df_plot = df_[df_.λ .== λ, :] - df_plot = df_plot[df_plot.temperature .== temp, :] - plt = AlgebraOfGraphics.data(df_plot) * visual(BoxPlot) * - mapping(:generator, :value, row=:variable, col=:model, color=:generator) - plt = draw( - plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=200, height=180), - facet=(; linkyaxes=:minimal) - ) - display(plt) - end -end -``` diff --git a/notebooks/tables.Rmd b/notebooks/tables.Rmd index 5ade027770908484d69da2a83c1594ef03722a66..1ffcfe04dc15c9a7256775c29af32819faa80c08 100644 --- a/notebooks/tables.Rmd +++ b/notebooks/tables.Rmd @@ -248,15 +248,15 @@ col_names <- c( "Uncertainty ↓", "Validity ↑" ) +align_cols <- c(rep('l',3),rep('c',ncol(tab_full)-3)) kbl( tab_full, caption = "All results for all datasets: sample averages +/- one standard deviation over all counterfactuals. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). 
\\label{tab:results-full} \\newline", - align = "c", col.names=col_names, booktabs = F, escape=F, + align = "c", col.names=col_names, booktabs = T, escape=F, format="latex" ) %>% kable_styling(latex_options = c("scale_down")) %>% kable_paper(full_width = F) %>% - column_spec(1, bold = T) %>% - collapse_rows(columns = 1:3, latex_hline = "custom", valign = "middle", custom_latex_hline = 1:2) %>% + collapse_rows(columns = 1:3, latex_hline = "custom", valign = "top", custom_latex_hline = 1:2) %>% save_kable("paper/contents/table_all.tex") ``` @@ -275,14 +275,14 @@ col_names <- c( "Uncertainty ↓", "Validity ↑" ) +align_cols <- c(rep('l',3),rep('c',ncol(tab_full)-3)) kbl( tab_full, caption = "All results for all datasets: sample averages +/- one standard deviation over all valid counterfactuals. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). \\label{tab:results-full} \\newline", - align = "c", col.names=col_names, booktabs = F, escape=F, + align = "c", col.names=col_names, booktabs = T, escape=F, format="latex" ) %>% kable_styling(latex_options = c("scale_down")) %>% kable_paper(full_width = F) %>% - column_spec(1, bold = T) %>% - collapse_rows(columns = 1:3, latex_hline = "custom", valign = "middle", custom_latex_hline = 1:2) %>% + collapse_rows(columns = 1:3, latex_hline = "custom", valign = "top", custom_latex_hline = 1:2) %>% save_kable("paper/contents/table_all_valid.tex") ``` \ No newline at end of file diff --git a/paper/contents/table_all.tex b/paper/contents/table_all.tex index f7717d0fea4bc050243df771db912154698107d0..f2f7d5a564f1a27fa4dbefb485d243bfe4911711 100644 --- a/paper/contents/table_all.tex +++ b/paper/contents/table_all.tex @@ -1,147 +1,147 @@ \begin{table} -\caption{All results for all datasets. Standard deviations across samples are shown in parentheses. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). \label{tab:results-full} \newline} +\caption{All results for all datasets: sample averages +/- one standard deviation over all counterfactuals. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). 
\label{tab:results-full} \newline} \centering \resizebox{\linewidth}{!}{ -\begin{tabular}[t]{>{}c|c|c|c|c|c|c|c|c} -\hline +\begin{tabular}[t]{ccccccccc} +\toprule Model & Data & Generator & Cost ↓ & Unfaithfulness ↓ & Implausibility ↓ & Redundancy ↑ & Uncertainty ↓ & Validity ↑\\ -\hline - & & ECCCo & 0.74 (0.21) & 0.52 (0.36) & 1.22 (0.46) & 0.00 (0.00) & 0.00 (0.00) & \textbf{1.00 (0.00)**}\\ +\midrule + & & ECCCo & 0.74 ± 0.21\hphantom{*}\hphantom{*} & 0.52 ± 0.36\hphantom{*}\hphantom{*} & 1.22 ± 0.46\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}**\\ - & & ECCCo (no CP) & 0.72 (0.21) & 0.54 (0.39) & 1.21 (0.46) & 0.00 (0.00) & 0.00 (0.00) & \textbf{1.00 (0.00)**}\\ + & & ECCCo (no CP) & 0.72 ± 0.21\hphantom{*}\hphantom{*} & 0.54 ± 0.39\hphantom{*}\hphantom{*} & 1.21 ± 0.46\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}**\\ - & & ECCCo (no EBM) & 0.52 (0.15) & 0.70 (0.33) & 1.30 (0.37) & 0.00 (0.00) & 0.00 (0.00) & \textbf{1.00 (0.00)**}\\ + & & ECCCo (no EBM) & 0.52 ± 0.15\hphantom{*}\hphantom{*} & 0.70 ± 0.33\hphantom{*}\hphantom{*} & 1.30 ± 0.37\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}**\\ - & & REVISE & 0.97 (0.34) & \textbf{0.48 (0.16)*} & \textbf{0.95 (0.32)*} & 0.00 (0.00) & 0.00 (0.00) & 0.50 (0.51)\\ + & & REVISE & 0.97 ± 0.34\hphantom{*}\hphantom{*} & \textbf{0.48 ± 0.16}*\hphantom{*} & \textbf{0.95 ± 0.32}*\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.50 ± 0.51\hphantom{*}\hphantom{*}\\ - & & Schut & 1.06 (0.43) & 0.54 (0.43) & 1.28 (0.53) & \textbf{0.26 (0.25)*} & 0.00 (0.00) & \textbf{1.00 (0.00)**}\\ + & & Schut & 1.06 ± 0.43\hphantom{*}\hphantom{*} & 0.54 ± 0.43\hphantom{*}\hphantom{*} & 1.28 ± 0.53\hphantom{*}\hphantom{*} & \textbf{0.26 ± 0.25}*\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}**\\ - & \multirow{-6}{*}{\centering\arraybackslash JEM} & Wachter & \textbf{0.44 (0.16)} & 0.68 (0.34) & 1.33 (0.32) & 0.00 (0.00) & 0.00 (0.00) & 0.98 (0.14)\\ -\cline{2-9} - & & ECCCo & 0.67 (0.19) & 0.65 (0.53) & 1.17 (0.41) & 0.00 (0.00) & 0.09 (0.19)** & \textbf{1.00 (0.00)}\\ + & \multirow[t]{-6}{*}{\centering\arraybackslash JEM} & Wachter & \textbf{0.44 ± 0.16}\hphantom{*}\hphantom{*} & 0.68 ± 0.34\hphantom{*}\hphantom{*} & 1.33 ± 0.32\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.98 ± 0.14\hphantom{*}\hphantom{*}\\ +\cmidrule{2-9} + & & ECCCo & 0.67 ± 0.19\hphantom{*}\hphantom{*} & 0.65 ± 0.53\hphantom{*}\hphantom{*} & 1.17 ± 0.41\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.09 ± 0.19** & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\ - & & ECCCo (no CP) & 0.71 (0.16) & \textbf{0.49 (0.35)} & 1.19 (0.44) & 0.00 (0.00) & 0.05 (0.16)** & \textbf{1.00 (0.00)}\\ + & & ECCCo (no CP) & 0.71 ± 0.16\hphantom{*}\hphantom{*} & \textbf{0.49 ± 0.35}\hphantom{*}\hphantom{*} & 1.19 ± 0.44\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.05 ± 0.16** & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\ - & & ECCCo (no EBM) & 0.45 (0.11) & 0.84 (0.51) & 1.23 (0.31) & 0.00 (0.00) & 0.15 (0.23)* & \textbf{1.00 (0.00)}\\ + & & ECCCo (no EBM) & 0.45 ± 0.11\hphantom{*}\hphantom{*} & 0.84 ± 0.51\hphantom{*}\hphantom{*} & 1.23 ± 0.31\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.15 ± 
0.23*\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
- & & REVISE & 0.96 (0.31) & 0.58 (0.52) & \textbf{0.95 (0.32)} & 0.00 (0.00) & \textbf{0.00 (0.00)**} & 0.50 (0.51)\\
+ & & REVISE & 0.96 ± 0.31\hphantom{*}\hphantom{*} & 0.58 ± 0.52\hphantom{*}\hphantom{*} & \textbf{0.95 ± 0.32}\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & 0.50 ± 0.51\hphantom{*}\hphantom{*}\\
- & & Schut & 0.57 (0.11) & 0.58 (0.37) & 1.23 (0.43) & \textbf{0.43 (0.18)**} & \textbf{0.00 (0.00)**} & \textbf{1.00 (0.00)}\\
+ & & Schut & 0.57 ± 0.11\hphantom{*}\hphantom{*} & 0.58 ± 0.37\hphantom{*}\hphantom{*} & 1.23 ± 0.43\hphantom{*}\hphantom{*} & \textbf{0.43 ± 0.18}** & \textbf{0.00 ± 0.00}** & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
-\multirow{-12}{*}{\centering\arraybackslash \textbf{Circles}} & \multirow{-6}{*}{\centering\arraybackslash MLP} & Wachter & \textbf{0.40 (0.09)} & 0.83 (0.50) & 1.24 (0.29) & 0.00 (0.00) & 0.53 (0.01) & \textbf{1.00 (0.00)}\\
-\cline{1-9}
- & & ECCCo & 19.32 (4.51)** & \textbf{79.45 (11.98)**} & 22.05 (10.58)** & 0.00 (0.00) & \textbf{0.07 (0.03)} & 0.85 (0.37)\\
+\multirow[t]{-12}{*}{\centering\arraybackslash Circles} & \multirow[t]{-6}{*}{\centering\arraybackslash MLP} & Wachter & \textbf{0.40 ± 0.09}\hphantom{*}\hphantom{*} & 0.83 ± 0.50\hphantom{*}\hphantom{*} & 1.24 ± 0.29\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.53 ± 0.01\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
+\cmidrule{1-9}
+ & & ECCCo & 19.32 ± 4.51** & \textbf{79.45 ± 11.98}** & 22.05 ± 10.58** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.07 ± 0.03}\hphantom{*}\hphantom{*} & 0.85 ± 0.37\hphantom{*}\hphantom{*}\\
- & & REVISE & 3.66 (2.25)** & 187.06 (31.29) & \textbf{7.06 (7.73)**} & 0.00 (0.00) & 0.37 (0.21) & \textbf{0.95 (0.22)}\\
+ & & REVISE & 3.66 ± 2.25** & 187.06 ± 31.29\hphantom{*}\hphantom{*} & \textbf{7.06 ± 7.73}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.37 ± 0.21\hphantom{*}\hphantom{*} & \textbf{0.95 ± 0.22}\hphantom{*}\hphantom{*}\\
- & & Schut & \textbf{1.56 (1.75)**} & 185.64 (37.42) & 8.47 (8.68)** & \textbf{0.69 (0.19)**} & 0.08 (0.02) & \textbf{0.95 (0.22)}\\
+ & & Schut & \textbf{1.56 ± 1.75}** & 185.64 ± 37.42\hphantom{*}\hphantom{*} & 8.47 ± 8.68** & \textbf{0.69 ± 0.19}** & 0.08 ± 0.02\hphantom{*}\hphantom{*} & \textbf{0.95 ± 0.22}\hphantom{*}\hphantom{*}\\
- & \multirow{-4}{*}{\centering\arraybackslash JEM} & Wachter & 65.38 (61.49) & 186.20 (42.26) & 70.79 (58.72) & 0.00 (0.00) & 0.08 (0.02) & \textbf{0.95 (0.22)}\\
-\cline{2-9}
- & & ECCCo & 16.90 (4.81)** & \textbf{79.65 (11.83)**} & 17.81 (5.44)** & 0.00 (0.00) & 0.17 (0.19) & 1.00 (0.00)\\
+ & \multirow[t]{-4}{*}{\centering\arraybackslash JEM} & Wachter & 65.38 ± 61.49\hphantom{*}\hphantom{*} & 186.20 ± 42.26\hphantom{*}\hphantom{*} & 70.79 ± 58.72\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.08 ± 0.02\hphantom{*}\hphantom{*} & \textbf{0.95 ± 0.22}\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
+ & & ECCCo & 16.90 ± 4.81** & \textbf{79.65 ± 11.83}** & 17.81 ± 5.44** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.17 ± 0.19\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
- & & REVISE & 2.97 (0.95)** & 204.14 (36.13) & \textbf{4.90 (0.95)**} & 0.00 (0.00) & 0.35 (0.18) & 1.00 (0.00)\\
+ & & REVISE & 2.97 ± 0.95** & 204.14 ± 36.13\hphantom{*}\hphantom{*} & \textbf{4.90 ± 0.95}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.35 ± 0.18\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
- & & Schut & \textbf{1.23 (0.30)**} & 186.24 (36.18) & 6.35 (1.22)** & \textbf{0.66 (0.06)**} & 0.13 (0.06) & 1.00 (0.00)\\
+ & & Schut & \textbf{1.23 ± 0.30}** & 186.24 ± 36.18\hphantom{*}\hphantom{*} & 6.35 ± 1.22** & \textbf{0.66 ± 0.06}** & 0.13 ± 0.06\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
- & \multirow{-4}{*}{\centering\arraybackslash JEM Ensemble} & Wachter & 57.72 (49.41) & 184.05 (23.11) & 61.40 (48.29) & 0.01 (0.02) & \textbf{0.11 (0.02)} & 1.00 (0.00)\\
-\cline{2-9}
- & & ECCCo & 22.47 (6.06)** & \textbf{79.84 (15.97)**} & 26.78 (11.64)** & 0.00 (0.00) & \textbf{0.11 (0.05)} & 0.85 (0.37)\\
+ & \multirow[t]{-4}{*}{\centering\arraybackslash JEM Ensemble} & Wachter & 57.72 ± 49.41\hphantom{*}\hphantom{*} & 184.05 ± 23.11\hphantom{*}\hphantom{*} & 61.40 ± 48.29\hphantom{*}\hphantom{*} & 0.01 ± 0.02\hphantom{*}\hphantom{*} & \textbf{0.11 ± 0.02}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
+ & & ECCCo & 22.47 ± 6.06** & \textbf{79.84 ± 15.97}** & 26.78 ± 11.64** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.11 ± 0.05}\hphantom{*}\hphantom{*} & 0.85 ± 0.37\hphantom{*}\hphantom{*}\\
- & & REVISE & 7.29 (12.81)** & 180.18 (30.75) & \textbf{5.05 (1.05)**} & 0.00 (0.00) & 0.31 (0.14) & \textbf{1.00 (0.00)**}\\
+ & & REVISE & 7.29 ± 12.81** & 180.18 ± 30.75\hphantom{*}\hphantom{*} & \textbf{5.05 ± 1.05}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.31 ± 0.14\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}**\\
- & & Schut & \textbf{2.67 (2.71)**} & 196.86 (45.07) & 11.16 (12.19)** & \textbf{0.67 (0.25)**} & 0.12 (0.04) & 0.90 (0.31)\\
+ & & Schut & \textbf{2.67 ± 2.71}** & 196.86 ± 45.07\hphantom{*}\hphantom{*} & 11.16 ± 12.19** & \textbf{0.67 ± 0.25}** & 0.12 ± 0.04\hphantom{*}\hphantom{*} & 0.90 ± 0.31\hphantom{*}\hphantom{*}\\
- & \multirow{-4}{*}{\centering\arraybackslash MLP} & Wachter & 81.98 (54.19) & 196.51 (31.36) & 81.50 (54.31) & 0.00 (0.00) & 0.12 (0.04) & 0.90 (0.31)\\
-\cline{2-9}
- & & ECCCo & 22.45 (8.45)** & \textbf{76.32 (14.56)**} & 22.99 (8.31)** & 0.00 (0.00) & 0.13 (0.00) & \textbf{1.00 (0.00)**}\\
+ & \multirow[t]{-4}{*}{\centering\arraybackslash MLP} & Wachter & 81.98 ± 54.19\hphantom{*}\hphantom{*} & 196.51 ± 31.36\hphantom{*}\hphantom{*} & 81.50 ± 54.31\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.12 ± 0.04\hphantom{*}\hphantom{*} & 0.90 ± 0.31\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
+ & & ECCCo & 22.45 ± 8.45** & \textbf{76.32 ± 14.56}** & 22.99 ± 8.31** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.13 ± 0.00\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}**\\
- & & REVISE & 3.16 (0.91)** & 184.04 (29.13)* & \textbf{5.25 (1.31)**} & 0.00 (0.00) & 0.27 (0.11) & \textbf{1.00 (0.00)**}\\
+ & & REVISE & 3.16 ± 0.91** & 184.04 ± 29.13*\hphantom{*} & \textbf{5.25 ± 1.31}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.27 ± 0.11\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}**\\
- & & Schut & \textbf{0.61 (0.24)**} & 214.74 (34.33) & 6.18 (1.17)** & \textbf{0.89 (0.03)**} & 0.13 (0.00) & \textbf{1.00 (0.00)**}\\
+ & & Schut & \textbf{0.61 ± 0.24}** & 214.74 ± 34.33\hphantom{*}\hphantom{*} & 6.18 ± 1.17** & \textbf{0.89 ± 0.03}** & 0.13 ± 0.00\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}**\\
-\multirow{-16}{*}{\centering\arraybackslash \textbf{GMSC}} & \multirow{-4}{*}{\centering\arraybackslash MLP Ensemble} & Wachter & 60.72 (53.52) & 216.50 (41.31) & 64.04 (52.79) & 0.00 (0.00) & \textbf{0.06 (0.06)} & 0.50 (0.51)\\
-\cline{1-9}
- & & ECCCo & 0.75 (0.17) & \textbf{0.03 (0.06)**} & \textbf{0.20 (0.08)**} & 0.00 (0.00) & \textbf{0.00 (0.00)} & \textbf{1.00 (0.00)}\\
+\multirow[t]{-16}{*}{\centering\arraybackslash GMSC} & \multirow[t]{-4}{*}{\centering\arraybackslash MLP Ensemble} & Wachter & 60.72 ± 53.52\hphantom{*}\hphantom{*} & 216.50 ± 41.31\hphantom{*}\hphantom{*} & 64.04 ± 52.79\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.06 ± 0.06}\hphantom{*}\hphantom{*} & 0.50 ± 0.51\hphantom{*}\hphantom{*}\\
+\cmidrule{1-9}
+ & & ECCCo & 0.75 ± 0.17\hphantom{*}\hphantom{*} & \textbf{0.03 ± 0.06}** & \textbf{0.20 ± 0.08}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
- & & ECCCo (no CP) & 0.75 (0.17) & 0.03 (0.06)** & 0.20 (0.08)** & 0.00 (0.00) & \textbf{0.00 (0.00)} & \textbf{1.00 (0.00)}\\
+ & & ECCCo (no CP) & 0.75 ± 0.17\hphantom{*}\hphantom{*} & 0.03 ± 0.06** & 0.20 ± 0.08** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
- & & ECCCo (no EBM) & 0.70 (0.16) & 0.16 (0.11) & 0.34 (0.19) & 0.00 (0.00) & \textbf{0.00 (0.00)} & \textbf{1.00 (0.00)}\\
+ & & ECCCo (no EBM) & 0.70 ± 0.16\hphantom{*}\hphantom{*} & 0.16 ± 0.11\hphantom{*}\hphantom{*} & 0.34 ± 0.19\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
- & & REVISE & \textbf{0.41 (0.15)} & 0.19 (0.03) & 0.41 (0.01)** & 0.00 (0.00) & 0.36 (0.36) & 0.50 (0.51)\\
+ & & REVISE & \textbf{0.41 ± 0.15}\hphantom{*}\hphantom{*} & 0.19 ± 0.03\hphantom{*}\hphantom{*} & 0.41 ± 0.01** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.36 ± 0.36\hphantom{*}\hphantom{*} & 0.50 ± 0.51\hphantom{*}\hphantom{*}\\
- & & Schut & 1.15 (0.35) & 0.39 (0.07) & 0.73 (0.17) & \textbf{0.25 (0.25)} & \textbf{0.00 (0.00)} & \textbf{1.00 (0.00)}\\
+ & & Schut & 1.15 ± 0.35\hphantom{*}\hphantom{*} & 0.39 ± 0.07\hphantom{*}\hphantom{*} & 0.73 ± 0.17\hphantom{*}\hphantom{*} & \textbf{0.25 ± 0.25}\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
- & \multirow{-6}{*}{\centering\arraybackslash JEM} & Wachter & 0.50 (0.13) & 0.18 (0.10) & 0.44 (0.17) & 0.00 (0.00) & \textbf{0.00 (0.00)} & \textbf{1.00 (0.00)}\\
-\cline{2-9}
- & & ECCCo & 0.95 (0.16) & \textbf{0.29 (0.05)**} & 0.23 (0.06)** & 0.00 (0.00) & \textbf{0.00 (0.00)**} & \textbf{1.00 (0.00)}\\
+ & \multirow[t]{-6}{*}{\centering\arraybackslash JEM} & Wachter & 0.50 ± 0.13\hphantom{*}\hphantom{*} & 0.18 ± 0.10\hphantom{*}\hphantom{*} & 0.44 ± 0.17\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
+ & & ECCCo & 0.95 ± 0.16\hphantom{*}\hphantom{*} & \textbf{0.29 ± 0.05}** & 0.23 ± 0.06** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
- & & ECCCo (no CP) & 0.94 (0.16) & 0.29 (0.05)** & \textbf{0.23 (0.07)**} & 0.00 (0.00) & \textbf{0.00 (0.00)**} & \textbf{1.00 (0.00)}\\
+ & & ECCCo (no CP) & 0.94 ± 0.16\hphantom{*}\hphantom{*} & 0.29 ± 0.05** & \textbf{0.23 ± 0.07}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
- & & ECCCo (no EBM) & 0.60 (0.15) & 0.46 (0.05) & 0.28 (0.04)** & 0.00 (0.00) & 0.02 (0.10)** & \textbf{1.00 (0.00)}\\
+ & & ECCCo (no EBM) & 0.60 ± 0.15\hphantom{*}\hphantom{*} & 0.46 ± 0.05\hphantom{*}\hphantom{*} & 0.28 ± 0.04** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.02 ± 0.10** & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
- & & REVISE & \textbf{0.42 (0.14)} & 0.56 (0.05) & 0.41 (0.01) & 0.00 (0.00) & 0.47 (0.50) & 0.48 (0.50)\\
+ & & REVISE & \textbf{0.42 ± 0.14}\hphantom{*}\hphantom{*} & 0.56 ± 0.05\hphantom{*}\hphantom{*} & 0.41 ± 0.01\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.47 ± 0.50\hphantom{*}\hphantom{*} & 0.48 ± 0.50\hphantom{*}\hphantom{*}\\
- & & Schut & 0.77 (0.17) & 0.43 (0.06)* & 0.47 (0.36) & \textbf{0.20 (0.25)} & \textbf{0.00 (0.00)**} & \textbf{1.00 (0.00)}\\
+ & & Schut & 0.77 ± 0.17\hphantom{*}\hphantom{*} & 0.43 ± 0.06*\hphantom{*} & 0.47 ± 0.36\hphantom{*}\hphantom{*} & \textbf{0.20 ± 0.25}\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
-\multirow{-12}{*}{\centering\arraybackslash \textbf{Linearly Separable}} & \multirow{-6}{*}{\centering\arraybackslash MLP} & Wachter & 0.51 (0.15) & 0.51 (0.04) & 0.40 (0.08) & 0.00 (0.00) & 0.59 (0.02) & \textbf{1.00 (0.00)}\\
-\cline{1-9}
- & & ECCCo & 334.61 (46.37) & \textbf{19.28 (5.01)**} & 314.76 (32.36)* & 0.00 (0.00) & 4.43 (0.56) & \textbf{0.98 (0.12)}\\
+\multirow[t]{-12}{*}{\centering\arraybackslash Linearly Separable} & \multirow[t]{-6}{*}{\centering\arraybackslash MLP} & Wachter & 0.51 ± 0.15\hphantom{*}\hphantom{*} & 0.51 ± 0.04\hphantom{*}\hphantom{*} & 0.40 ± 0.08\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.59 ± 0.02\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
+\cmidrule{1-9}
+ & & ECCCo & 334.61 ± 46.37\hphantom{*}\hphantom{*} & \textbf{19.28 ± 5.01}** & 314.76 ± 32.36*\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 4.43 ± 0.56\hphantom{*}\hphantom{*} & \textbf{0.98 ± 0.12}\hphantom{*}\hphantom{*}\\
- & & REVISE & 170.68 (63.26) & 188.70 (26.18)* & \textbf{255.26 (41.50)**} & 0.00 (0.00) & 4.39 (0.91) & 0.96 (0.20)\\
+ & & REVISE & 170.68 ± 63.26\hphantom{*}\hphantom{*} & 188.70 ± 26.18*\hphantom{*} & \textbf{255.26 ± 41.50}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 4.39 ± 0.91\hphantom{*}\hphantom{*} & 0.96 ± 0.20\hphantom{*}\hphantom{*}\\
- & & Schut & \textbf{9.44 (1.60)**} & 211.00 (27.21) & 286.61 (39.85)* & \textbf{0.99 (0.00)**} & \textbf{1.08 (1.95)*} & 0.24 (0.43)\\
+ & & Schut & \textbf{9.44 ± 1.60}** & 211.00 ± 27.21\hphantom{*}\hphantom{*} & 286.61 ± 39.85*\hphantom{*} & \textbf{0.99 ± 0.00}** & \textbf{1.08 ± 1.95}*\hphantom{*} & 0.24 ± 0.43\hphantom{*}\hphantom{*}\\
- & \multirow{-4}{*}{\centering\arraybackslash JEM} & Wachter & 128.36 (14.95) & 222.90 (26.56) & 361.88 (39.74) & 0.00 (0.00) & 4.37 (0.98) & 0.95 (0.21)\\
-\cline{2-9}
- & & ECCCo & 342.64 (41.14) & \textbf{15.99 (3.06)**} & 294.72 (30.75)** & 0.00 (0.00) & 2.07 (0.06)** & \textbf{1.00 (0.00)**}\\
+ & \multirow[t]{-4}{*}{\centering\arraybackslash JEM} & Wachter & 128.36 ± 14.95\hphantom{*}\hphantom{*} & 222.90 ± 26.56\hphantom{*}\hphantom{*} & 361.88 ± 39.74\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 4.37 ± 0.98\hphantom{*}\hphantom{*} & 0.95 ± 0.21\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
+ & & ECCCo & 342.64 ± 41.14\hphantom{*}\hphantom{*} & \textbf{15.99 ± 3.06}** & 294.72 ± 30.75** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 2.07 ± 0.06** & \textbf{1.00 ± 0.00}**\\
- & & REVISE & 170.21 (58.02) & 173.59 (20.65)** & \textbf{246.32 (37.46)**} & 0.00 (0.00) & 2.56 (0.83) & 0.93 (0.26)\\
+ & & REVISE & 170.21 ± 58.02\hphantom{*}\hphantom{*} & 173.59 ± 20.65** & \textbf{246.32 ± 37.46}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 2.56 ± 0.83\hphantom{*}\hphantom{*} & 0.93 ± 0.26\hphantom{*}\hphantom{*}\\
- & & Schut & \textbf{9.78 (1.02)**} & 205.33 (24.07) & 287.39 (39.33)* & \textbf{0.99 (0.00)**} & \textbf{0.32 (0.94)**} & 0.11 (0.31)\\
+ & & Schut & \textbf{9.78 ± 1.02}** & 205.33 ± 24.07\hphantom{*}\hphantom{*} & 287.39 ± 39.33*\hphantom{*} & \textbf{0.99 ± 0.00}** & \textbf{0.32 ± 0.94}** & 0.11 ± 0.31\hphantom{*}\hphantom{*}\\
- & \multirow{-4}{*}{\centering\arraybackslash JEM Ensemble} & Wachter & 135.07 (16.79) & 217.67 (23.78) & 363.23 (39.24) & 0.00 (0.00) & 2.93 (0.77) & 0.94 (0.23)\\
-\cline{2-9}
- & & ECCCo & 605.17 (44.78) & \textbf{41.95 (6.50)**} & 591.58 (36.24) & 0.00 (0.00) & 0.57 (0.00)** & \textbf{1.00 (0.00)**}\\
+ & \multirow[t]{-4}{*}{\centering\arraybackslash JEM Ensemble} & Wachter & 135.07 ± 16.79\hphantom{*}\hphantom{*} & 217.67 ± 23.78\hphantom{*}\hphantom{*} & 363.23 ± 39.24\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 2.93 ± 0.77\hphantom{*}\hphantom{*} & 0.94 ± 0.23\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
+ & & ECCCo & 605.17 ± 44.78\hphantom{*}\hphantom{*} & \textbf{41.95 ± 6.50}** & 591.58 ± 36.24\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.57 ± 0.00** & \textbf{1.00 ± 0.00}**\\
- & & REVISE & 146.61 (36.96) & 365.82 (15.35)* & \textbf{249.49 (41.55)**} & 0.00 (0.00) & 0.62 (0.30) & 0.87 (0.34)\\
+ & & REVISE & 146.61 ± 36.96\hphantom{*}\hphantom{*} & 365.82 ± 15.35*\hphantom{*} & \textbf{249.49 ± 41.55}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.62 ± 0.30\hphantom{*}\hphantom{*} & 0.87 ± 0.34\hphantom{*}\hphantom{*}\\
- & & Schut & \textbf{9.95 (0.37)**} & 382.44 (17.81) & 285.98 (42.48)* & \textbf{0.99 (0.00)**} & \textbf{0.05 (0.19)**} & 0.06 (0.24)\\
+ & & Schut & \textbf{9.95 ± 0.37}** & 382.44 ± 17.81\hphantom{*}\hphantom{*} & 285.98 ± 42.48*\hphantom{*} & \textbf{0.99 ± 0.00}** & \textbf{0.05 ± 0.19}** & 0.06 ± 0.24\hphantom{*}\hphantom{*}\\
- & \multirow{-4}{*}{\centering\arraybackslash MLP} & Wachter & 136.08 (16.09) & 386.05 (16.60) & 361.83 (42.18) & 0.00 (0.00) & 0.68 (0.36) & 0.84 (0.36)\\
-\cline{2-9}
- & & ECCCo & 525.87 (34.00) & \textbf{31.43 (3.91)**} & 490.88 (27.19) & 0.00 (0.00) & 0.29 (0.00)** & \textbf{1.00 (0.00)**}\\
+ & \multirow[t]{-4}{*}{\centering\arraybackslash MLP} & Wachter & 136.08 ± 16.09\hphantom{*}\hphantom{*} & 386.05 ± 16.60\hphantom{*}\hphantom{*} & 361.83 ± 42.18\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.68 ± 0.36\hphantom{*}\hphantom{*} & 0.84 ± 0.36\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
+ & & ECCCo & 525.87 ± 34.00\hphantom{*}\hphantom{*} & \textbf{31.43 ± 3.91}** & 490.88 ± 27.19\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.29 ± 0.00** & \textbf{1.00 ± 0.00}**\\
- & & REVISE & 146.60 (35.64) & 337.74 (11.89)* & \textbf{247.67 (38.36)**} & 0.00 (0.00) & 0.39 (0.22) & 0.85 (0.36)\\
+ & & REVISE & 146.60 ± 35.64\hphantom{*}\hphantom{*} & 337.74 ± 11.89*\hphantom{*} & \textbf{247.67 ± 38.36}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.39 ± 0.22\hphantom{*}\hphantom{*} & 0.85 ± 0.36\hphantom{*}\hphantom{*}\\
- & & Schut & \textbf{9.98 (0.25)**} & 359.54 (14.52) & 283.99 (41.08)* & \textbf{0.99 (0.00)**} & \textbf{0.03 (0.14)**} & 0.06 (0.24)\\
+ & & Schut & \textbf{9.98 ± 0.25}** & 359.54 ± 14.52\hphantom{*}\hphantom{*} & 283.99 ± 41.08*\hphantom{*} & \textbf{0.99 ± 0.00}** & \textbf{0.03 ± 0.14}** & 0.06 ± 0.24\hphantom{*}\hphantom{*}\\
-\multirow{-16}{*}{\centering\arraybackslash \textbf{MNIST}} & \multirow{-4}{*}{\centering\arraybackslash MLP Ensemble} & Wachter & 137.53 (18.95) & 360.79 (14.39) & 357.73 (42.55) & 0.00 (0.00) & 0.47 (0.64) & 0.80 (0.40)\\
-\cline{1-9}
- & & ECCCo & 1.56 (0.44) & \textbf{0.31 (0.30)*} & \textbf{1.20 (0.15)**} & 0.00 (0.00) & \textbf{0.00 (0.00)**} & \textbf{1.00 (0.00)**}\\
+\multirow[t]{-16}{*}{\centering\arraybackslash MNIST} & \multirow[t]{-4}{*}{\centering\arraybackslash MLP Ensemble} & Wachter & 137.53 ± 18.95\hphantom{*}\hphantom{*} & 360.79 ± 14.39\hphantom{*}\hphantom{*} & 357.73 ± 42.55\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.47 ± 0.64\hphantom{*}\hphantom{*} & 0.80 ± 0.40\hphantom{*}\hphantom{*}\\
+\cmidrule{1-9}
+ & & ECCCo & 1.56 ± 0.44\hphantom{*}\hphantom{*} & \textbf{0.31 ± 0.30}*\hphantom{*} & \textbf{1.20 ± 0.15}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & \textbf{1.00 ± 0.00}**\\
- & & ECCCo (no CP) & 1.56 (0.46) & 0.37 (0.30)* & 1.21 (0.17)** & 0.00 (0.00) & \textbf{0.00 (0.00)**} & \textbf{1.00 (0.00)**}\\
+ & & ECCCo (no CP) & 1.56 ± 0.46\hphantom{*}\hphantom{*} & 0.37 ± 0.30*\hphantom{*} & 1.21 ± 0.17** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & \textbf{1.00 ± 0.00}**\\
- & & ECCCo (no EBM) & 0.80 (0.25) & 0.91 (0.32) & 1.71 (0.25) & 0.00 (0.00) & \textbf{0.00 (0.00)**} & \textbf{1.00 (0.00)**}\\
+ & & ECCCo (no EBM) & 0.80 ± 0.25\hphantom{*}\hphantom{*} & 0.91 ± 0.32\hphantom{*}\hphantom{*} & 1.71 ± 0.25\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & \textbf{1.00 ± 0.00}**\\
- & & REVISE & 1.04 (0.43) & 0.78 (0.23) & 1.57 (0.26) & 0.00 (0.00) & \textbf{0.00 (0.00)**} & \textbf{1.00 (0.00)**}\\
+ & & REVISE & 1.04 ± 0.43\hphantom{*}\hphantom{*} & 0.78 ± 0.23\hphantom{*}\hphantom{*} & 1.57 ± 0.26\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & \textbf{1.00 ± 0.00}**\\
- & & Schut & 1.12 (0.31) & 0.67 (0.27) & 1.50 (0.22)* & \textbf{0.08 (0.19)} & \textbf{0.00 (0.00)**} & 0.98 (0.14)\\
+ & & Schut & 1.12 ± 0.31\hphantom{*}\hphantom{*} & 0.67 ± 0.27\hphantom{*}\hphantom{*} & 1.50 ± 0.22*\hphantom{*} & \textbf{0.08 ± 0.19}\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & 0.98 ± 0.14\hphantom{*}\hphantom{*}\\
- & \multirow{-6}{*}{\centering\arraybackslash JEM} & Wachter & \textbf{0.72 (0.24)} & 0.80 (0.27) & 1.78 (0.24) & 0.00 (0.00) & 0.02 (0.10) & 0.98 (0.14)\\
-\cline{2-9}
- & & ECCCo & 2.18 (1.05) & 0.80 (0.62) & 1.69 (0.40) & 0.00 (0.00) & 0.15 (0.24)* & \textbf{1.00 (0.00)}\\
+ & \multirow[t]{-6}{*}{\centering\arraybackslash JEM} & Wachter & \textbf{0.72 ± 0.24}\hphantom{*}\hphantom{*} & 0.80 ± 0.27\hphantom{*}\hphantom{*} & 1.78 ± 0.24\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.02 ± 0.10\hphantom{*}\hphantom{*} & 0.98 ± 0.14\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
+ & & ECCCo & 2.18 ± 1.05\hphantom{*}\hphantom{*} & 0.80 ± 0.62\hphantom{*}\hphantom{*} & 1.69 ± 0.40\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.15 ± 0.24*\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
- & & ECCCo (no CP) & 2.07 (1.15) & \textbf{0.79 (0.62)} & 1.68 (0.42) & 0.00 (0.00) & 0.15 (0.24)* & \textbf{1.00 (0.00)}\\
+ & & ECCCo (no CP) & 2.07 ± 1.15\hphantom{*}\hphantom{*} & \textbf{0.79 ± 0.62}\hphantom{*}\hphantom{*} & 1.68 ± 0.42\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.15 ± 0.24*\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
- & & ECCCo (no EBM) & 1.25 (0.92) & 1.34 (0.47) & 1.68 (0.47) & 0.00 (0.00) & 0.43 (0.18) & \textbf{1.00 (0.00)}\\
+ & & ECCCo (no EBM) & 1.25 ± 0.92\hphantom{*}\hphantom{*} & 1.34 ± 0.47\hphantom{*}\hphantom{*} & 1.68 ± 0.47\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.43 ± 0.18\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
- & & REVISE & 0.79 (0.19)* & 1.45 (0.44) & \textbf{1.64 (0.31)} & 0.00 (0.00) & 0.40 (0.22) & \textbf{1.00 (0.00)}\\
+ & & REVISE & 0.79 ± 0.19*\hphantom{*} & 1.45 ± 0.44\hphantom{*}\hphantom{*} & \textbf{1.64 ± 0.31}\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.40 ± 0.22\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
- & & Schut & \textbf{0.73 (0.25)*} & 1.45 (0.55) & 1.73 (0.48) & \textbf{0.31 (0.28)*} & \textbf{0.00 (0.00)**} & 0.90 (0.30)\\
+ & & Schut & \textbf{0.73 ± 0.25}*\hphantom{*} & 1.45 ± 0.55\hphantom{*}\hphantom{*} & 1.73 ± 0.48\hphantom{*}\hphantom{*} & \textbf{0.31 ± 0.28}*\hphantom{*} & \textbf{0.00 ± 0.00}** & 0.90 ± 0.30\hphantom{*}\hphantom{*}\\
-\multirow{-12}{*}{\centering\arraybackslash \textbf{Moons}} & \multirow{-6}{*}{\centering\arraybackslash MLP} & Wachter & 1.08 (0.83) & 1.32 (0.41) & 1.69 (0.32) & 0.00 (0.00) & 0.52 (0.08) & \textbf{1.00 (0.00)}\\
-\hline
+\multirow[t]{-12}{*}{\centering\arraybackslash Moons} & \multirow[t]{-6}{*}{\centering\arraybackslash MLP} & Wachter & 1.08 ± 0.83\hphantom{*}\hphantom{*} & 1.32 ± 0.41\hphantom{*}\hphantom{*} & 1.69 ± 0.32\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.52 ± 0.08\hphantom{*}\hphantom{*} & \textbf{1.00 ± 0.00}\hphantom{*}\hphantom{*}\\
+\bottomrule
 \end{tabular}}
 \end{table}
diff --git a/paper/contents/table_all_valid.tex b/paper/contents/table_all_valid.tex
index 0d22bde074413425ed994b8ca3b1c627ce8f4be6..5d4e466bed79d0d056956aec587c32a05a23cd93 100644
--- a/paper/contents/table_all_valid.tex
+++ b/paper/contents/table_all_valid.tex
@@ -1,12 +1,12 @@
 \begin{table}
-\caption{All results for all datasets including only valid counterfactuals. Standard deviations across samples are shown in parentheses. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). \label{tab:results-full} \newline}
+\caption{All results for all datasets: sample averages ± one standard deviation over all valid counterfactuals. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). \label{tab:results-full} \newline}
 \centering
 \resizebox{\linewidth}{!}{
-\begin{tabular}[t]{>{}c|c|c|c|c|c|c|c|c}
-\hline
+\begin{tabular}[t]{ccccccccc}
+\toprule
 Model & Data & Generator & Cost ↓ & Unfaithfulness ↓ & Implausibility ↓ & Redundancy ↑ & Uncertainty ↓ & Validity ↑\\
-\hline
+\midrule
 & & ECCCo & 0.74 ± 0.21\hphantom{*}\hphantom{*} & 0.52 ± 0.36\hphantom{*}\hphantom{*} & 1.22 ± 0.46\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & ECCCo (no CP) & 0.72 ± 0.21\hphantom{*}\hphantom{*} & 0.54 ± 0.39\hphantom{*}\hphantom{*} & 1.21 ± 0.46\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
@@ -17,8 +17,8 @@ Model & Data & Generator & Cost ↓ & Unfaithfulness ↓ & Implausibility ↓ &
 & & Schut & 1.06 ± 0.43\hphantom{*}\hphantom{*} & 0.54 ± 0.43\hphantom{*}\hphantom{*} & 1.28 ± 0.53\hphantom{*}\hphantom{*} & \textbf{0.26 ± 0.25}*\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
- & \multirow{-6}{*}{\centering\arraybackslash JEM} & Wachter & \textbf{0.45 ± 0.15}\hphantom{*}\hphantom{*} & 0.68 ± 0.34\hphantom{*}\hphantom{*} & 1.33 ± 0.32\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{2-9}
+ & \multirow[t]{-6}{*}{\centering\arraybackslash JEM} & Wachter & \textbf{0.45 ± 0.15}\hphantom{*}\hphantom{*} & 0.68 ± 0.34\hphantom{*}\hphantom{*} & 1.33 ± 0.32\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
 & & ECCCo & 0.67 ± 0.19\hphantom{*}\hphantom{*} & 0.65 ± 0.53\hphantom{*}\hphantom{*} & 1.17 ± 0.41\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.09 ± 0.19** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & ECCCo (no CP) & 0.71 ± 0.16\hphantom{*}\hphantom{*} & 0.49 ± 0.35\hphantom{*}\hphantom{*} & 1.19 ± 0.44\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.05 ± 0.16** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
@@ -29,40 +29,40 @@ Model & Data & Generator & Cost ↓ & Unfaithfulness ↓ & Implausibility ↓ &
 & & Schut & 0.57 ± 0.11\hphantom{*}\hphantom{*} & 0.58 ± 0.37\hphantom{*}\hphantom{*} & 1.23 ± 0.43\hphantom{*}\hphantom{*} & \textbf{0.43 ± 0.18}** & \textbf{0.00 ± 0.00}** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\multirow{-12}{*}{\centering\arraybackslash \textbf{Circles}} & \multirow{-6}{*}{\centering\arraybackslash MLP} & Wachter & \textbf{0.40 ± 0.09}\hphantom{*}\hphantom{*} & 0.83 ± 0.50\hphantom{*}\hphantom{*} & 1.24 ± 0.29\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.53 ± 0.01\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{1-9}
+\multirow[t]{-12}{*}{\centering\arraybackslash Circles} & \multirow[t]{-6}{*}{\centering\arraybackslash MLP} & Wachter & \textbf{0.40 ± 0.09}\hphantom{*}\hphantom{*} & 0.83 ± 0.50\hphantom{*}\hphantom{*} & 1.24 ± 0.29\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.53 ± 0.01\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{1-9}
 & & ECCCo & 19.20 ± 4.90** & \textbf{79.18 ± 13.01}** & 19.67 ± 6.27** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.09 ± 0.00\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & REVISE & 3.29 ± 1.59** & 186.05 ± 31.81\hphantom{*}\hphantom{*} & \textbf{5.38 ± 1.89}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.38 ± 0.20\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & Schut & \textbf{1.19 ± 0.70}** & 185.40 ± 38.43\hphantom{*}\hphantom{*} & 6.54 ± 0.98** & \textbf{0.73 ± 0.10}** & 0.09 ± 0.00\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
- & \multirow{-4}{*}{\centering\arraybackslash JEM} & Wachter & 68.49 ± 61.55\hphantom{*}\hphantom{*} & 188.81 ± 41.72\hphantom{*}\hphantom{*} & 71.97 ± 60.09\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.08 ± 0.00}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{2-9}
+ & \multirow[t]{-4}{*}{\centering\arraybackslash JEM} & Wachter & 68.49 ± 61.55\hphantom{*}\hphantom{*} & 188.81 ± 41.72\hphantom{*}\hphantom{*} & 71.97 ± 60.09\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.08 ± 0.00}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
 & & ECCCo & 16.90 ± 4.81** & \textbf{79.65 ± 11.83}** & 17.81 ± 5.44** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.17 ± 0.19\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & REVISE & 2.97 ± 0.95** & 204.14 ± 36.13\hphantom{*}\hphantom{*} & \textbf{4.90 ± 0.95}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.35 ± 0.18\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & Schut & \textbf{1.23 ± 0.30}** & 186.24 ± 36.18\hphantom{*}\hphantom{*} & 6.35 ± 1.22** & \textbf{0.66 ± 0.06}** & 0.13 ± 0.06\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
- & \multirow{-4}{*}{\centering\arraybackslash JEM Ensemble} & Wachter & 57.72 ± 49.41\hphantom{*}\hphantom{*} & 184.05 ± 23.11\hphantom{*}\hphantom{*} & 61.40 ± 48.29\hphantom{*}\hphantom{*} & 0.01 ± 0.02\hphantom{*}\hphantom{*} & \textbf{0.11 ± 0.02}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{2-9}
+ & \multirow[t]{-4}{*}{\centering\arraybackslash JEM Ensemble} & Wachter & 57.72 ± 49.41\hphantom{*}\hphantom{*} & 184.05 ± 23.11\hphantom{*}\hphantom{*} & 61.40 ± 48.29\hphantom{*}\hphantom{*} & 0.01 ± 0.02\hphantom{*}\hphantom{*} & \textbf{0.11 ± 0.02}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
 & & ECCCo & 23.22 ± 6.26** & \textbf{80.51 ± 16.59}** & 23.43 ± 6.09** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.14 ± 0.00\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & REVISE & 7.29 ± 12.81** & 180.18 ± 30.75\hphantom{*}\hphantom{*} & \textbf{5.05 ± 1.05}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.31 ± 0.14\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & Schut & \textbf{1.85 ± 1.08}** & 199.88 ± 45.58\hphantom{*}\hphantom{*} & 7.25 ± 1.88** & \textbf{0.74 ± 0.10}** & 0.14 ± 0.00\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
- & \multirow{-4}{*}{\centering\arraybackslash MLP} & Wachter & 85.89 ± 55.86\hphantom{*}\hphantom{*} & 196.33 ± 33.11\hphantom{*}\hphantom{*} & 87.52 ± 53.98\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.13 ± 0.00}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{2-9}
+ & \multirow[t]{-4}{*}{\centering\arraybackslash MLP} & Wachter & 85.89 ± 55.86\hphantom{*}\hphantom{*} & 196.33 ± 33.11\hphantom{*}\hphantom{*} & 87.52 ± 53.98\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.13 ± 0.00}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
 & & ECCCo & 22.45 ± 8.45\hphantom{*}\hphantom{*} & \textbf{76.32 ± 14.56}** & 22.99 ± 8.31\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.13 ± 0.00\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & REVISE & 3.16 ± 0.91** & 184.04 ± 29.13\hphantom{*}\hphantom{*} & \textbf{5.25 ± 1.31}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.27 ± 0.11\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & Schut & \textbf{0.61 ± 0.24}** & 214.74 ± 34.33\hphantom{*}\hphantom{*} & 6.18 ± 1.17** & \textbf{0.89 ± 0.03}** & 0.13 ± 0.00\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\multirow{-16}{*}{\centering\arraybackslash \textbf{GMSC}} & \multirow{-4}{*}{\centering\arraybackslash MLP Ensemble} & Wachter & 8.73 ± 6.23\hphantom{*}\hphantom{*} & 193.41 ± 35.45\hphantom{*}\hphantom{*} & 12.71 ± 4.90\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.13 ± 0.00}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{1-9}
+\multirow[t]{-16}{*}{\centering\arraybackslash GMSC} & \multirow[t]{-4}{*}{\centering\arraybackslash MLP Ensemble} & Wachter & 8.73 ± 6.23\hphantom{*}\hphantom{*} & 193.41 ± 35.45\hphantom{*}\hphantom{*} & 12.71 ± 4.90\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.13 ± 0.00}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{1-9}
 & & ECCCo & 0.75 ± 0.17\hphantom{*}\hphantom{*} & \textbf{0.03 ± 0.06}** & \textbf{0.20 ± 0.08}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & ECCCo (no CP) & 0.75 ± 0.17\hphantom{*}\hphantom{*} & 0.03 ± 0.06** & 0.20 ± 0.08** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
@@ -73,8 +73,8 @@ Model & Data & Generator & Cost ↓ & Unfaithfulness ↓ & Implausibility ↓ &
 & & Schut & 1.15 ± 0.35\hphantom{*}\hphantom{*} & 0.39 ± 0.07\hphantom{*}\hphantom{*} & 0.73 ± 0.17\hphantom{*}\hphantom{*} & \textbf{0.25 ± 0.25}\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
- & \multirow{-6}{*}{\centering\arraybackslash JEM} & Wachter & 0.50 ± 0.13\hphantom{*}\hphantom{*} & 0.18 ± 0.10\hphantom{*}\hphantom{*} & 0.44 ± 0.17\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{2-9}
+ & \multirow[t]{-6}{*}{\centering\arraybackslash JEM} & Wachter & 0.50 ± 0.13\hphantom{*}\hphantom{*} & 0.18 ± 0.10\hphantom{*}\hphantom{*} & 0.44 ± 0.17\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
 & & ECCCo & 0.95 ± 0.16\hphantom{*}\hphantom{*} & \textbf{0.29 ± 0.05}** & 0.23 ± 0.06** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & ECCCo (no CP) & 0.94 ± 0.16\hphantom{*}\hphantom{*} & 0.29 ± 0.05** & \textbf{0.23 ± 0.07}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
@@ -85,40 +85,40 @@ Model & Data & Generator & Cost ↓ & Unfaithfulness ↓ & Implausibility ↓ &
 & & Schut & 0.77 ± 0.17\hphantom{*}\hphantom{*} & 0.43 ± 0.06*\hphantom{*} & 0.47 ± 0.36\hphantom{*}\hphantom{*} & \textbf{0.20 ± 0.25}\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\multirow{-12}{*}{\centering\arraybackslash \textbf{Linearly Separable}} & \multirow{-6}{*}{\centering\arraybackslash MLP} & Wachter & 0.51 ± 0.15\hphantom{*}\hphantom{*} & 0.51 ± 0.04\hphantom{*}\hphantom{*} & 0.40 ± 0.08\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.59 ± 0.02\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{1-9}
+\multirow[t]{-12}{*}{\centering\arraybackslash Linearly Separable} & \multirow[t]{-6}{*}{\centering\arraybackslash MLP} & Wachter & 0.51 ± 0.15\hphantom{*}\hphantom{*} & 0.51 ± 0.04\hphantom{*}\hphantom{*} & 0.40 ± 0.08\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.59 ± 0.02\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{1-9}
 & & ECCCo & 334.98 ± 46.54\hphantom{*}\hphantom{*} & \textbf{19.27 ± 5.02}** & 314.54 ± 32.54*\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{4.50 ± 0.00}** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & REVISE & 170.06 ± 62.45\hphantom{*}\hphantom{*} & 188.54 ± 26.22*\hphantom{*} & \textbf{254.32 ± 41.55}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 4.57 ± 0.14\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & Schut & \textbf{7.63 ± 2.55}** & 199.70 ± 28.43\hphantom{*}\hphantom{*} & 273.01 ± 39.60** & \textbf{0.99 ± 0.00}** & 4.56 ± 0.13\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
- & \multirow{-4}{*}{\centering\arraybackslash JEM} & Wachter & 128.13 ± 14.81\hphantom{*}\hphantom{*} & 222.81 ± 26.22\hphantom{*}\hphantom{*} & 361.38 ± 39.55\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 4.58 ± 0.16\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{2-9}
+ & \multirow[t]{-4}{*}{\centering\arraybackslash JEM} & Wachter & 128.13 ± 14.81\hphantom{*}\hphantom{*} & 222.81 ± 26.22\hphantom{*}\hphantom{*} & 361.38 ± 39.55\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 4.58 ± 0.16\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
 & & ECCCo & 342.64 ± 41.14\hphantom{*}\hphantom{*} & \textbf{15.99 ± 3.06}** & 294.72 ± 30.75** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{2.07 ± 0.06}** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & REVISE & 171.95 ± 58.81\hphantom{*}\hphantom{*} & 173.05 ± 20.38** & \textbf{246.20 ± 37.74}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 2.76 ± 0.45\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & Schut & \textbf{7.96 ± 2.49}** & 186.91 ± 22.98*\hphantom{*} & 264.68 ± 37.58** & \textbf{0.99 ± 0.00}** & 3.02 ± 0.26\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
- & \multirow{-4}{*}{\centering\arraybackslash JEM Ensemble} & Wachter & 134.98 ± 16.95\hphantom{*}\hphantom{*} & 217.37 ± 23.93\hphantom{*}\hphantom{*} & 362.91 ± 39.40\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 3.10 ± 0.31\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{2-9}
+ & \multirow[t]{-4}{*}{\centering\arraybackslash JEM Ensemble} & Wachter & 134.98 ± 16.95\hphantom{*}\hphantom{*} & 217.37 ± 23.93\hphantom{*}\hphantom{*} & 362.91 ± 39.40\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 3.10 ± 0.31\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
 & & ECCCo & 605.17 ± 44.78\hphantom{*}\hphantom{*} & \textbf{41.95 ± 6.50}** & 591.58 ± 36.24\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.57 ± 0.00}** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & REVISE & 146.76 ± 37.07\hphantom{*}\hphantom{*} & 365.69 ± 14.90*\hphantom{*} & 245.36 ± 39.69** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.72 ± 0.18\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & Schut & \textbf{9.25 ± 1.31}** & 371.12 ± 19.99\hphantom{*}\hphantom{*} & \textbf{245.11 ± 35.72}** & \textbf{0.99 ± 0.00}** & 0.75 ± 0.23\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
- & \multirow{-4}{*}{\centering\arraybackslash MLP} & Wachter & 135.08 ± 15.68\hphantom{*}\hphantom{*} & 384.76 ± 16.52\hphantom{*}\hphantom{*} & 359.21 ± 42.03\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.81 ± 0.22\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{2-9}
+ & \multirow[t]{-4}{*}{\centering\arraybackslash MLP} & Wachter & 135.08 ± 15.68\hphantom{*}\hphantom{*} & 384.76 ± 16.52\hphantom{*}\hphantom{*} & 359.21 ± 42.03\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.81 ± 0.22\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
 & & ECCCo & 525.87 ± 34.00\hphantom{*}\hphantom{*} & \textbf{31.43 ± 3.91}** & 490.88 ± 27.19\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.29 ± 0.00}** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & REVISE & 146.38 ± 35.18\hphantom{*}\hphantom{*} & 337.21 ± 11.68*\hphantom{*} & \textbf{244.84 ± 37.17}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.45 ± 0.16\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & Schut & \textbf{9.75 ± 1.00}** & 344.60 ± 13.64*\hphantom{*} & 252.53 ± 37.92** & \textbf{0.99 ± 0.00}** & 0.55 ± 0.21\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\multirow{-16}{*}{\centering\arraybackslash \textbf{MNIST}} & \multirow{-4}{*}{\centering\arraybackslash MLP Ensemble} & Wachter & 134.48 ± 17.69\hphantom{*}\hphantom{*} & 358.51 ± 13.18\hphantom{*}\hphantom{*} & 352.63 ± 39.93\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.58 ± 0.67\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{1-9}
+\multirow[t]{-16}{*}{\centering\arraybackslash MNIST} & \multirow[t]{-4}{*}{\centering\arraybackslash MLP Ensemble} & Wachter & 134.48 ± 17.69\hphantom{*}\hphantom{*} & 358.51 ± 13.18\hphantom{*}\hphantom{*} & 352.63 ± 39.93\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.58 ± 0.67\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{1-9}
 & & ECCCo & 1.56 ± 0.44\hphantom{*}\hphantom{*} & \textbf{0.31 ± 0.30}*\hphantom{*} & \textbf{1.20 ± 0.15}** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & ECCCo (no CP) & 1.56 ± 0.46\hphantom{*}\hphantom{*} & 0.37 ± 0.30*\hphantom{*} & 1.21 ± 0.17** & 0.00 ± 0.00\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
@@ -129,8 +129,8 @@ Model & Data & Generator & Cost ↓ & Unfaithfulness ↓ & Implausibility ↓ &
 & & Schut & 1.13 ± 0.29\hphantom{*}\hphantom{*} & 0.66 ± 0.25\hphantom{*}\hphantom{*} & 1.47 ± 0.10** & \textbf{0.07 ± 0.18}\hphantom{*}\hphantom{*} & \textbf{0.00 ± 0.00}** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
- & \multirow{-6}{*}{\centering\arraybackslash JEM} & Wachter & \textbf{0.73 ± 0.24}\hphantom{*}\hphantom{*} & 0.78 ± 0.23\hphantom{*}\hphantom{*} & 1.75 ± 0.19\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.02 ± 0.11\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\cline{2-9}
+ & \multirow[t]{-6}{*}{\centering\arraybackslash JEM} & Wachter & \textbf{0.73 ± 0.24}\hphantom{*}\hphantom{*} & 0.78 ± 0.23\hphantom{*}\hphantom{*} & 1.75 ± 0.19\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.02 ± 0.11\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\cmidrule{2-9}
 & & ECCCo & 2.18 ± 1.05\hphantom{*}\hphantom{*} & 0.80 ± 0.62\hphantom{*}\hphantom{*} & 1.69 ± 0.40\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.15 ± 0.24*\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
 & & ECCCo (no CP) & 2.07 ± 1.15\hphantom{*}\hphantom{*} & \textbf{0.79 ± 0.62}\hphantom{*}\hphantom{*} & 1.68 ± 0.42\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.15 ± 0.24*\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
@@ -141,7 +141,7 @@ Model & Data & Generator & Cost ↓ & Unfaithfulness ↓ & Implausibility ↓ &
 & & Schut & \textbf{0.78 ± 0.17}*\hphantom{*} & 1.39 ± 0.50\hphantom{*}\hphantom{*} & \textbf{1.59 ± 0.26}\hphantom{*}\hphantom{*} & \textbf{0.28 ± 0.25}*\hphantom{*} & \textbf{0.00 ± 0.00}** & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\multirow{-12}{*}{\centering\arraybackslash \textbf{Moons}} & \multirow{-6}{*}{\centering\arraybackslash MLP} & Wachter & 1.08 ± 0.83\hphantom{*}\hphantom{*} & 1.32 ± 0.41\hphantom{*}\hphantom{*} & 1.69 ± 0.32\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.52 ± 0.08\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
-\hline
+\multirow[t]{-12}{*}{\centering\arraybackslash Moons} & \multirow[t]{-6}{*}{\centering\arraybackslash MLP} & Wachter & 1.08 ± 0.83\hphantom{*}\hphantom{*} & 1.32 ± 0.41\hphantom{*}\hphantom{*} & 1.69 ± 0.32\hphantom{*}\hphantom{*} & 0.00 ± 0.00\hphantom{*}\hphantom{*} & 0.52 ± 0.08\hphantom{*}\hphantom{*} & 1.00 ± 0.00\hphantom{*}\hphantom{*}\\
+\bottomrule
 \end{tabular}}
 \end{table}
diff --git a/paper/paper.pdf b/paper/paper.pdf
index b46ccfa7086301083444e9773b952c2715ae3f93..06118315a41465bc9f4051211ab1cb3aa0088fe1 100644
Binary files a/paper/paper.pdf and b/paper/paper.pdf differ
diff --git a/paper/paper.tex b/paper/paper.tex
index 4732bc7de67ab3d1c3f0683abbab4e2feaf94919..909827ce8ebb8bc5427706dcf1aa56d204d53293 100644
--- a/paper/paper.tex
+++ b/paper/paper.tex
@@ -349,7 +349,7 @@ Since we were not able to identify any existing open-source software for Energy-
 
 \subsubsection{Training: Joint Energy Models}
 
-To train our Joint Energy Models we broadly follow the approach outlined in~\citet{grathwohl2020your}. These models are trained to optimize a hybrid objective that involves a standard classification loss component $L_{\text{clf}}(\theta)=-\log p_{\theta}(\mathbf{y}|\mathbf{x})$ (e.g. crossentropy loss) as well as a generative loss component $L_{\text{gen}}(\theta)=-\log p_{\theta}(\mathbf{x})$.
+To train our Joint Energy Models we broadly follow the approach outlined in~\citet{grathwohl2020your}. These models are trained to optimize a hybrid objective that involves a standard classification loss component $L_{\text{clf}}(\theta)=-\log p_{\theta}(\mathbf{y}|\mathbf{x})$ (e.g. cross-entropy loss) as well as a generative loss component $L_{\text{gen}}(\theta)=-\log p_{\theta}(\mathbf{x})$.
 
 To draw samples from $p_{\theta}(\mathbf{x})$, we rely exclusively on the conditional sampling approach described in~\citet{grathwohl2020your} for both training and inference: we first draw $\mathbf{y}\sim p(\mathbf{y})$ and then sample $\mathbf{x} \sim p_{\theta}(\mathbf{x}|\mathbf{y})$~\citep{grathwohl2020your} via Equation~\ref{eq:sgld} with energy $\mathcal{E}(\mathbf{x}|\mathbf{y})=\mu_{\theta}(\mathbf{x})[\mathbf{y}]$ where $\mu_{\theta}: \mathcal{X} \mapsto \mathbb{R}^K$ returns the linear predictions (logits) of our classifier $M_{\theta}$. While our package also supports unconditional sampling, we found conditional sampling to work well. It is also well aligned with CE, since in this context we are interested in conditioning on the target class.
 
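For readers who want to map the hybrid objective and the conditional sampler onto code, the following is a minimal Julia sketch under stated assumptions. It is not the implementation in our package: `f`, `sgld_sample` and `jem_loss` are illustrative stand-ins, the classifier `f` is assumed to map an input vector to $K$ logits $\mu_{\theta}(\mathbf{x})$, the sampled output is treated as a constant with respect to $\theta$, and all step sizes and penalty strengths are placeholder values (sign conventions for the energy vary across references).

```{julia}
using Flux

# Conditional SGLD sketch: draw y ~ p(y), then sample x ~ p_θ(x|y) by taking
# noisy gradient steps that raise the target-class logit μ_θ(x)[y].
function sgld_sample(f, y::Integer, input_dim::Integer; steps::Int=30, η::Real=1e-2)
    x = randn(Float32, input_dim)  # initialise from 𝒟x
    for _ in 1:steps
        ∇x = Flux.gradient(z -> f(z)[y], x)[1]
        x = x .+ Float32(η / 2) .* ∇x .+ Float32(sqrt(η)) .* randn(Float32, input_dim)
    end
    return x
end

# Hybrid loss for a single observation (x, y): classification loss plus a
# contrastive surrogate for the generative loss plus a Ridge penalty on
# energy magnitudes; α collects the penalty strengths.
function jem_loss(f, x::AbstractVector, y::Integer, K::Integer; α=(1.0, 1.0, 0.1))
    ŷ = rand(1:K)                      # y ~ p(y)
    x̂ = sgld_sample(f, ŷ, length(x))   # x ~ p_θ(x|y); held constant w.r.t. θ in practice
    L_clf = Flux.logitcrossentropy(f(x), Flux.onehot(y, 1:K))
    L_gen = f(x̂)[ŷ] - f(x)[y]          # raises the data logit relative to sampled data
    L_reg = f(x)[y]^2 + f(x̂)[ŷ]^2      # regularises energies of observed and generated samples
    return α[1] * L_clf + α[2] * L_gen + α[3] * L_reg
end
```

A full training loop would additionally use minibatching, a replay buffer for SGLD chains and gradient stopping through the sampler, as is standard for this family of models.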
@@ -373,7 +373,19 @@ It is important to realise that sampling is done during each training epoch, whi
 
 where $L_{\text{reg}}(\theta)$ is a Ridge penalty (L2 norm) that regularises energy magnitudes for both observed and generated samples~\citep{du2020implicit}. We have used varying degrees of regularization depending on the dataset.
 
-Contrary to existing work, we have not typically used the entire minibatch of training data for the generative loss component but found that using a subset of the minibatch was often sufficient in attaining decent generative performance. This has helped to reduce the computational burden for our models, which should make it easier for others to reproduce our findings.
+Contrary to existing work, we have not typically used the entire minibatch of training data for the generative loss component but found that using a subset of the minibatch was often sufficient to attain decent generative performance. This has helped to reduce the computational burden for our models, which should make it easier for others to reproduce our findings. Figures~\ref{fig:mnist-gen} and~\ref{fig:moons-gen} show generated samples for our \textit{MNIST} and \textit{Moons} data, to provide a sense of their generative property.
+
+\begin{figure}
+ \centering
+ \includegraphics[width=0.75\textwidth]{../artifacts/results/images/mnist_generated_JEM Ensemble.png}
+ \caption{Conditionally generated \textit{MNIST} images for our JEM Ensemble.}\label{fig:mnist-gen}
+\end{figure}
+
+\begin{figure}
+ \centering
+ \includegraphics[width=0.5\textwidth]{../artifacts/results/images/moons_generated_JEM.png}
+ \caption{Conditionally generated samples (stars) for our \textit{Moons} data using a JEM.}\label{fig:moons-gen}
+\end{figure}
 
 \subsubsection{Inference: Quantifying Models' Generative Property}
 
@@ -405,7 +417,13 @@ Observe from Equation~\ref{eq:scp} that Conformal Prediction works on an instanc
 
 The fact that conformal classifiers produce set-valued predictions introduces a challenge: it is not immediately obvious how to use such classifiers in the context of gradient-based counterfactual search. Put differently, it is not clear how to use prediction sets in Equation~\ref{eq:general}. Fortunately, \citet{stutz2022learning} have recently proposed a framework for Conformal Training that also hinges on differentiability. Specifically, they show how Stochastic Gradient Descent can be used to train classifiers not only for the discriminative task but also for additional objectives related to Conformal Prediction. One such objective is \textit{efficiency}: for a given target error rate $\alpha$, the efficiency of a conformal classifier improves as its average prediction set size decreases. To this end, the authors introduce a smooth set size penalty defined in Equation~\ref{eq:setsize} in the body of this paper. Formally, it is defined as $C_{\theta,\mathbf{y}}(\mathbf{x}_i;\alpha):=\sigma\left((s(\mathbf{x}_i,\mathbf{y})-\alpha) T^{-1}\right)$ for $\mathbf{y}\in\mathcal{Y}$, where $\sigma$ is the sigmoid function and $T$ is a hyper-parameter used for temperature scaling~\citep{stutz2022learning}.
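Since this penalty is what makes prediction sets usable in gradient-based search, a short Julia sketch of the formula above may help. It is a hedged illustration, not the code used in the experiments: the conformity score `s` is a stand-in, and the values of `α` and `T` are placeholders.

```{julia}
# Smooth set size penalty as defined above: C_{θ,y}(x; α) = σ((s(x, y) - α) / T),
# with σ the sigmoid and T a temperature hyper-parameter.
sigmoid(z) = 1 / (1 + exp(-z))

# Soft inclusion of label y in the prediction set for input x:
soft_inclusion(s, x, y; α=0.05, T=0.5) = sigmoid((s(x, y) - α) / T)

# Summing soft inclusions over all K labels yields a differentiable proxy for
# the prediction set size, which can be penalised during counterfactual search.
smooth_set_size(s, x, K; α=0.05, T=0.5) =
    sum(soft_inclusion(s, x, y; α=α, T=T) for y in 1:K)
```

Minimising this quantity shrinks the (soft) prediction set, which is exactly the efficiency objective described in the preceding paragraph.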
 
-In addition to the smooth set size penalty,~\citet{stutz2022learning} also propose a configurable classification loss function, that can be used to enforce coverage. For \textit{MNIST} data, we found that using this function generally improved the visual quality of the generated counterfactuals, so we used it in our experiments involving real-world data. For the synthetic dataset, visual inspection of the counterfactuals showed that using the configurable loss function sometimes led to overshooting: counterfactuals would end up deep inside the target domain but far away from the observed samples. For this reason we instead relied on standard crossentropy loss for our synthetic datasets. As we have noted in the body of the paper, more experimental work is certainly needed in this context.
+In addition to the smooth set size penalty,~\citet{stutz2022learning} also propose a configurable classification loss function that can be used to enforce coverage. For \textit{MNIST} data, we found that using this function generally improved the visual quality of the generated counterfactuals, so we used it in our experiments involving real-world data. For the synthetic datasets, visual inspection of the counterfactuals showed that using the configurable loss function sometimes led to overshooting: counterfactuals would end up deep inside the target domain but far away from the observed samples. For this reason, we instead relied on standard cross-entropy loss for our synthetic datasets. As we have noted in the body of the paper, more experimental work is certainly needed in this context. Figure~\ref{fig:cp-diff} shows the prediction set size (left), smooth set size loss (centre) and configurable classification loss (right) for a JEM trained on our \textit{Linearly Separable} data.
+
+\begin{figure}
+ \centering
+ \includegraphics[width=1.0\textwidth]{../artifacts/results/images/poc_set_size.png}
+ \caption{Prediction set size (left), smooth set size loss (centre) and configurable classification loss (right) for a JEM trained on our \textit{Linearly Separable} data.}\label{fig:cp-diff}
+\end{figure}
 
 \subsection{ECCCo}\label{app:eccco}