diff --git a/artifacts/results/cal_housing_models.jls b/artifacts/results/cal_housing_models.jls new file mode 100644 index 0000000000000000000000000000000000000000..d19ece952b9aa0d5ce52367d41e879c597848020 Binary files /dev/null and b/artifacts/results/cal_housing_models.jls differ diff --git a/notebooks/cal_housing.qmd b/notebooks/cal_housing.qmd new file mode 100644 index 0000000000000000000000000000000000000000..bea1cfa83b1662eda63eab353ccb25f3e081e72e --- /dev/null +++ b/notebooks/cal_housing.qmd @@ -0,0 +1,202 @@ +```{julia} +include("notebooks/setup.jl") +eval(setup_notebooks) +``` + +# Real-World Data + +```{julia} +# Hyper: +_retrain = true + +# Data: +n_obs = 10000 +counterfactual_data = load_california_housing(n_obs) +X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) +X = table(permutedims(X)) +labels = counterfactual_data.output_encoder.labels +input_dim, n_obs = size(counterfactual_data.X) +output_dim = length(unique(labels)) +``` + +First, let's create a couple of image classifier architectures: + +```{julia} +# Model parameters: +epochs = 100 +batch_size = minimum([Int(round(n_obs/10)), 128]) +n_hidden = 64 +activation = Flux.relu +builder = MLJFlux.@builder Flux.Chain( + Dense(n_in, n_hidden, activation), + Dense(n_hidden, n_out), +) +n_ens = 5 # number of models in ensemble +_loss = Flux.Losses.logitcrossentropy # loss function +_finaliser = x -> x # finaliser function +``` + +```{julia} +# JEM parameters: +ð’Ÿx = Normal() +ð’Ÿy = Categorical(ones(output_dim) ./ output_dim) +sampler = ConditionalSampler( + ð’Ÿx, ð’Ÿy, + input_size=(input_dim,), + batch_size=10, +) +α = [1.0,1.0,1e-2] # penalty strengths +``` + +```{julia} +# Simple MLP: +mlp = NeuralNetworkClassifier( + builder=builder, + epochs=epochs, + batch_size=batch_size, + finaliser=_finaliser, + loss=_loss, +) + +# Deep Ensemble: +mlp_ens = EnsembleModel(model=mlp, n=n_ens) + +# Joint Energy Model: +jem = JointEnergyClassifier( + sampler; + builder=builder, + epochs=epochs, + batch_size=batch_size, + finaliser=_finaliser, + loss=_loss, + jem_training_params=( + α=α,verbosity=10, + ), + sampling_steps=20, +) + +# JEM with adversarial training: +jem_adv = deepcopy(jem) +# jem_adv.adv_training = true + +# Deep Ensemble of Joint Energy Models: +jem_ens = EnsembleModel(model=jem, n=n_ens) + +# Deep Ensemble of Joint Energy Models with adversarial training: +# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens) + +# Dictionary of models: +models = Dict( + "MLP" => mlp, + "MLP Ensemble" => mlp_ens, + "JEM" => jem, + "JEM Ensemble" => jem_ens, + # "JEM Ensemble+" => jem_ens_plus, +) +``` + + +```{julia} +# Train models: +function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model") + conf_model = conformal_model(model; method=method, coverage=cov) + mach = machine(conf_model, X, y) + @info "Begin training $mod_name." + fit!(mach) + @info "Finished training $mod_name." + M = ECCCo.ConformalModel(mach.model, mach.fitresult) + return M +end +if _retrain + model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models) + Serialization.serialize(joinpath(output_path,"cal_housing_models.jls"), model_dict) +else + model_dict = Serialization.deserialize(joinpath(output_path,"cal_housing_models.jls")) +end +``` + +```{julia} +# # Evaluate models: + +# measure = Dict( +# :f1score => multiclass_f1score, +# :acc => accuracy, +# :precision => multiclass_precision +# ) +# model_performance = DataFrame() +# for (mod_name, mod) in model_dict +# # Test performance: +# test_data = load_mnist_test() +# _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure))) +# _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) +# _perf.mod_name .= mod_name +# model_performance = vcat(model_performance, _perf) +# end +# Serialization.serialize(joinpath(output_path,"cal_housing_model_performance.jls"), model_performance) +# CSV.write(joinpath(output_path, "cal_housing_model_performance.csv"), model_performance) +# model_performance +``` + +## Benchmark + +```{julia} +# Benchmark generators: +generator_dict = Dict( + :wachter => WachterGenerator(), + :revise => REVISEGenerator(), + :greedy => GreedyGenerator(), + :eccco => ECCCoGenerator(), +) + +# Measures: +measures = [ + CounterfactualExplanations.distance, + ECCCo.distance_from_energy, + ECCCo.distance_from_targets, + CounterfactualExplanations.Evaluation.validity, + CounterfactualExplanations.Evaluation.redundancy, +] + +bmk = benchmark( + counterfactual_data; + models=model_dict, + generators=generator_dict, + measure=measures, + suppress_training=true, dataname="California Housing", + n_individuals=5, + factual=0, target=1, + initialization=:identity, +) +CSV.write(joinpath(output_path, "cal_housing_benchmark.csv"), bmk()) +``` + + +```{julia} +@chain bmk() begin + @group_by(dataname, generator, model, variable) + @summarize(mean=mean(value),sd=std(value)) + @ungroup + @filter(variable == "distance_from_energy") +end +``` + + +```{julia} +df = @chain bmk() begin + @filter(variable in [ + "distance_from_energy", + "distance_from_targets", + "distance",]) + @mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable)) + @mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable)) + @mutate(variable = ifelse.(variable .== "distance", "Cost", variable)) +end +plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) * + mapping(:generator, :value, row=:variable, col=:model, color=:generator) +plt = draw( + plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120), + facet=(; linkyaxes=:minimal) +) +display(plt) +save(joinpath(output_images_path, "cal_housing_benchmark.png"), plt, px_per_unit=5) +``` \ No newline at end of file diff --git a/notebooks/gmsc.qmd b/notebooks/gmsc.qmd new file mode 100644 index 0000000000000000000000000000000000000000..fa63937a6b0051f1734b6ab2c2cfab5e78099993 --- /dev/null +++ b/notebooks/gmsc.qmd @@ -0,0 +1,202 @@ +```{julia} +include("notebooks/setup.jl") +eval(setup_notebooks) +``` + +# Real-World Data + +```{julia} +# Hyper: +_retrain = true + +# Data: +n_obs = 10000 +counterfactual_data = load_california_housing(n_obs) +X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) +X = table(permutedims(X)) +labels = counterfactual_data.output_encoder.labels +input_dim, n_obs = size(counterfactual_data.X) +output_dim = length(unique(labels)) +``` + +First, let's create a couple of image classifier architectures: + +```{julia} +# Model parameters: +epochs = 100 +batch_size = minimum([Int(round(n_obs/10)), 128]) +n_hidden = 64 +activation = Flux.relu +builder = MLJFlux.@builder Flux.Chain( + Dense(n_in, n_hidden, activation), + Dense(n_hidden, n_out), +) +n_ens = 5 # number of models in ensemble +_loss = Flux.Losses.logitcrossentropy # loss function +_finaliser = x -> x # finaliser function +``` + +```{julia} +# JEM parameters: +ð’Ÿx = Normal() +ð’Ÿy = Categorical(ones(output_dim) ./ output_dim) +sampler = ConditionalSampler( + ð’Ÿx, ð’Ÿy, + input_size=(input_dim,), + batch_size=10, +) +α = [1.0,1.0,1e-2] # penalty strengths +``` + +```{julia} +# Simple MLP: +mlp = NeuralNetworkClassifier( + builder=builder, + epochs=epochs, + batch_size=batch_size, + finaliser=_finaliser, + loss=_loss, +) + +# Deep Ensemble: +mlp_ens = EnsembleModel(model=mlp, n=n_ens) + +# Joint Energy Model: +jem = JointEnergyClassifier( + sampler; + builder=builder, + epochs=epochs, + batch_size=batch_size, + finaliser=_finaliser, + loss=_loss, + jem_training_params=( + α=α,verbosity=10, + ), + sampling_steps=20, +) + +# JEM with adversarial training: +jem_adv = deepcopy(jem) +# jem_adv.adv_training = true + +# Deep Ensemble of Joint Energy Models: +jem_ens = EnsembleModel(model=jem, n=n_ens) + +# Deep Ensemble of Joint Energy Models with adversarial training: +# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens) + +# Dictionary of models: +models = Dict( + "MLP" => mlp, + "MLP Ensemble" => mlp_ens, + "JEM" => jem, + "JEM Ensemble" => jem_ens, + # "JEM Ensemble+" => jem_ens_plus, +) +``` + + +```{julia} +# Train models: +function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model") + conf_model = conformal_model(model; method=method, coverage=cov) + mach = machine(conf_model, X, y) + @info "Begin training $mod_name." + fit!(mach) + @info "Finished training $mod_name." + M = ECCCo.ConformalModel(mach.model, mach.fitresult) + return M +end +if _retrain + model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models) + Serialization.serialize(joinpath(output_path,"gmsc_models.jls"), model_dict) +else + model_dict = Serialization.deserialize(joinpath(output_path,"gmsc_models.jls")) +end +``` + +```{julia} +# # Evaluate models: + +# measure = Dict( +# :f1score => multiclass_f1score, +# :acc => accuracy, +# :precision => multiclass_precision +# ) +# model_performance = DataFrame() +# for (mod_name, mod) in model_dict +# # Test performance: +# test_data = load_mnist_test() +# _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure))) +# _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) +# _perf.mod_name .= mod_name +# model_performance = vcat(model_performance, _perf) +# end +# Serialization.serialize(joinpath(output_path,"gmsc_model_performance.jls"), model_performance) +# CSV.write(joinpath(output_path, "gmsc_model_performance.csv"), model_performance) +# model_performance +``` + +## Benchmark + +```{julia} +# Benchmark generators: +generator_dict = Dict( + :wachter => WachterGenerator(), + :revise => REVISEGenerator(), + :greedy => GreedyGenerator(), + :eccco => ECCCoGenerator(), +) + +# Measures: +measures = [ + CounterfactualExplanations.distance, + ECCCo.distance_from_energy, + ECCCo.distance_from_targets, + CounterfactualExplanations.Evaluation.validity, + CounterfactualExplanations.Evaluation.redundancy, +] + +bmk = benchmark( + counterfactual_data; + models=model_dict, + generators=generator_dict, + measure=measures, + suppress_training=true, dataname="Californian Housing", + n_individuals=5, + factual=0, target=1, + initialization=:identity, +) +CSV.write(joinpath(output_path, "gmsc_benchmark.csv"), bmk()) +``` + + +```{julia} +@chain bmk() begin + @group_by(dataname, generator, model, variable) + @summarize(mean=mean(value),sd=std(value)) + @ungroup + @filter(variable == "distance_from_energy") +end +``` + + +```{julia} +df = @chain bmk() begin + @filter(variable in [ + "distance_from_energy", + "distance_from_targets", + "distance",]) + @mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable)) + @mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable)) + @mutate(variable = ifelse.(variable .== "distance", "Cost", variable)) +end +plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) * + mapping(:generator, :value, row=:variable, col=:model, color=:generator) +plt = draw( + plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120), + facet=(; linkyaxes=:minimal) +) +display(plt) +save(joinpath(output_images_path, "gmsc_benchmark.png"), plt, px_per_unit=5) +``` \ No newline at end of file diff --git a/notebooks/real_world.qmd b/notebooks/real_world.qmd deleted file mode 100644 index 465dce5171b9b631413a9d6c027075ada68a04f6..0000000000000000000000000000000000000000 --- a/notebooks/real_world.qmd +++ /dev/null @@ -1,420 +0,0 @@ -```{julia} -include("notebooks/setup.jl") -eval(setup_notebooks) -``` - -# Real-World Data - -```{julia} -# Hyper: -_retrain = false -_regen = false - -# Data: -n_obs = 10000 -datasets = load_tabular_data(n_obs; drop=:credit_default) -X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) -X = table(permutedims(X)) -x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1) -labels = counterfactual_data.output_encoder.labels -input_dim, n_obs = size(counterfactual_data.X) -n_digits = Int(sqrt(input_dim)) -output_dim = length(unique(labels)) -``` - -First, let's create a couple of image classifier architectures: - -```{julia} -# Model parameters: -epochs = 100 -batch_size = minimum([Int(round(n_obs/10)), 128]) -n_hidden = 64 -activation = Flux.relu -builder = MLJFlux.@builder Flux.Chain( - Dense(n_in, n_hidden, activation), - Dense(n_hidden, n_out), -) -n_ens = 5 # number of models in ensemble -_loss = Flux.Losses.logitcrossentropy # loss function -_finaliser = x -> x # finaliser function -``` - -```{julia} -# JEM parameters: -ð’Ÿx = Uniform(0,1) -ð’Ÿy = Categorical(ones(output_dim) ./ output_dim) -sampler = ConditionalSampler( - ð’Ÿx, ð’Ÿy, - input_size=(input_dim,), - batch_size=10, -) -α = [1.0,1.0,1e-2] # penalty strengths -``` - -```{julia} -# Simple MLP: -mlp = NeuralNetworkClassifier( - builder=builder, - epochs=epochs, - batch_size=batch_size, - finaliser=_finaliser, - loss=_loss, -) - -# Deep Ensemble: -mlp_ens = EnsembleModel(model=mlp, n=n_ens) - -# Joint Energy Model: -jem = JointEnergyClassifier( - sampler; - builder=builder, - epochs=epochs, - batch_size=batch_size, - finaliser=_finaliser, - loss=_loss, - jem_training_params=( - α=α,verbosity=10, - ), - sampling_steps=20, -) - -# JEM with adversarial training: -jem_adv = deepcopy(jem) -# jem_adv.adv_training = true - -# Deep Ensemble of Joint Energy Models: -jem_ens = EnsembleModel(model=jem, n=n_ens) - -# Deep Ensemble of Joint Energy Models with adversarial training: -# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens) - -# Dictionary of models: -models = Dict( - "MLP" => mlp, - "MLP Ensemble" => mlp_ens, - "JEM" => jem, - "JEM Ensemble" => jem_ens, - # "JEM Ensemble+" => jem_ens_plus, -) -``` - - -```{julia} -# Train models: -function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model") - conf_model = conformal_model(model; method=method, coverage=cov) - mach = machine(conf_model, X, y) - @info "Begin training $mod_name." - fit!(mach) - @info "Finished training $mod_name." - M = ECCCo.ConformalModel(mach.model, mach.fitresult) - return M -end -if _retrain - model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models) - Serialization.serialize(joinpath(output_path,"mnist_models.jls"), model_dict) -else - model_dict = Serialization.deserialize(joinpath(output_path,"mnist_models.jls")) -end -``` - -```{julia} -# Plot generated samples: -n_regen = 150 -if _regen - for (mod_name, mod) in model_dict - if ECCCo._has_sampler(mod) - sampler = ECCCo._get_sampler(mod) - else - K = length(counterfactual_data.y_levels) - input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1)) - ð’Ÿx = Uniform(extrema(counterfactual_data.X)...) - ð’Ÿy = Categorical(ones(K) ./ K) - sampler = ConditionalSampler(ð’Ÿx, ð’Ÿy; input_size=input_size) - end - opt = ImproperSGLD() - f(x) = logits(mod, x) - - _w = 1500 - plts = [] - neach = 10 - for i in 1:10 - x = sampler(f, opt; niter=n_regen, n_samples=neach, y=i) - plts_i = [] - for j in 1:size(x, 2) - xj = x[:,j] - xj = reshape(xj, (n_digits, n_digits)) - plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)] - end - plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10)) - plts = [plts..., plt] - end - plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name) - savefig(plt, joinpath(output_images_path, "mnist_generated_$(mod_name).png")) - display(plt) - end -end -``` - -```{julia} -# Evaluate models: - -measure = Dict( - :f1score => multiclass_f1score, - :acc => accuracy, - :precision => multiclass_precision -) -model_performance = DataFrame() -for (mod_name, mod) in model_dict - # Test performance: - test_data = load_mnist_test() - _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure))) - _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) - _perf.mod_name .= mod_name - model_performance = vcat(model_performance, _perf) -end -Serialization.serialize(joinpath(output_path,"mnist_model_performance.jls"), model_performance) -CSV.write(joinpath(output_path, "mnist_model_performance.csv"), model_performance) -model_performance -``` - -### Different Models - -```{julia} -function _plot_eccco_mnist( - x::Union{AbstractArray, Int}=x_factual, target::Int=target; - λ=[0.5,0.1,0.5], - temp=0.1,η=0.01, - plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"], - opt = Flux.Optimise.Adam(η), - rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG, -) - - # Setup: - Random.seed!(rng) - if x isa Int - x = reshape(counterfactual_data.X[:,rand(findall(labels.==x))],input_dim,1) - end - - # Generate counterfactuals using ECCCo generator: - eccco_generator = ECCCoGenerator( - λ=λ, - temp=temp, - opt=opt, - ) - - ces = Dict() - for (mod_name, mod) in model_dict - ce = generate_counterfactual( - x, target, counterfactual_data, mod, eccco_generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, - ) - ces[mod_name] = ce - end - _plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order) - - # Plot: - p1 = Plots.plot( - convert2image(MNIST, reshape(x,28,28)), - axis=nothing, - size=(img_height, img_height), - title="Factual" - ) - - plts = [] - for (_name,ce) in ces - _x = CounterfactualExplanations.counterfactual(ce) - _phat = target_probs(ce) - _title = "$_name (pÌ‚=$(round(_phat[1]; digits=3)))" - plt = Plots.plot( - convert2image(MNIST, reshape(_x,28,28)), - axis=nothing, - size=(img_height, img_height), - title=_title - ) - plts = [plts..., plt] - end - plts = plts[_plt_order] - plts = [p1, plts...] - plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) - - return plt, eccco_generator -end -``` - -```{julia} -plt, eccco_generator = _plot_eccco_mnist() -display(plt) -savefig(plt, joinpath(output_images_path, "mnist_eccco.png")) -``` - -### All digits - -```{julia} -function plot_mnist( - factual::Int,target::Int; - generator::AbstractGenerator, - model::AbstractFittedModel=model_dict["JEM Ensemble"], - data::CounterfactualData=counterfactual_data, - rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG, - _plot_title::Bool=true, - kwargs..., -) - decision_threshold = !isdefined(kwargs, :decision_threshold) ? 0.9 : decision_threshold - max_iter = !isdefined(kwargs, :max_iter) ? 100 : max_iter - initialization = !isdefined(kwargs, :initialization) ? :identity : initialization - converge_when = !isdefined(kwargs, :converge_when) ? :generator_conditions : converge_when - - x = reshape(data.X[:,rand(findall(predict_label(model, data).==factual))],input_dim,1) - ce = generate_counterfactual( - x, target, data, model, generator; - decision_threshold=decision_threshold, max_iter=max_iter, - initialization=initialization, - converge_when=converge_when, - kwargs... - ) - - _title = _plot_title ? "$(factual) -> $(target)" : "" - - _x = CounterfactualExplanations.counterfactual(ce) - plt = Plots.plot( - convert2image(MNIST, reshape(_x,28,28)), - axis=nothing, - size=(img_height, img_height), - title=_title - ) - return plt -end -``` - -```{julia} -if _regen - function plot_all_digits(rng=1;verbose=true,kwargs...) - plts = [] - for i in 0:9 - for j in 0:9 - @info "Generating counterfactual for $(i) -> $(j)" - plt = plot_mnist(i,j;kwargs...,rng=rng) - !verbose || display(plt) - plts = [plts..., plt] - end - end - plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10)) - return plt - end - plt = plot_all_digits(generator=eccco_generator) - savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits.png")) -end -``` - -### Different Generators - -```{julia} -# Setup: -model = model_dict["JEM Ensemble"] - -# Benchmark generators: -generator_dict = Dict( - :wachter => generic_generator, - :revise => revise_generator, - :greedy => greedy_generator, - :eccco => eccco_generator, -) - -ces = Dict() -for (gen_name, gen) in generator_dict - ce = generate_counterfactual( - x_factual, target, counterfactual_data, model, gen; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, - ) - ces[gen_name] = ce -end -plt_order = sortperm(collect(keys(ces))) - -# Plot: -p1 = Plots.plot( - convert2image(MNIST, reshape(x_factual,28,28)), - axis=nothing, - size=(img_height, img_height), - title="Factual" -) - -plts = [] -for (_name,ce) in ces - _x = CounterfactualExplanations.counterfactual(ce) - _phat = target_probs(ce) - _title = "$_name (pÌ‚=$(round(_phat[1]; digits=3)))" - plt = Plots.plot( - convert2image(MNIST, reshape(_x,28,28)), - axis=nothing, - size=(img_height, img_height), - title=_title - ) - plts = [plts..., plt] -end -plts = plts[plt_order] -plts = [p1, plts...] -plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) -display(plt) -savefig(plt, joinpath(output_images_path, "mnist_all_generators.png")) -``` - -## Benchmark - -```{julia} -# Measures: -measures = [ - CounterfactualExplanations.distance, - ECCCo.distance_from_energy, - ECCCo.distance_from_targets, - CounterfactualExplanations.Evaluation.validity, - CounterfactualExplanations.Evaluation.redundancy, -] - -bmk = benchmark( - counterfactual_data; - models=model_dict, - generators=generator_dict, - measure=measures, - suppress_training=true, dataname="MNIST", - n_individuals=5, - factual=0, target=1, - initialization=:identity, -) -CSV.write(joinpath(output_path, "mnist_benchmark.csv"), bmk()) -``` - - -```{julia} -@chain bmk() begin - @group_by(dataname, generator, model, variable) - @summarize(mean=mean(value),sd=std(value)) - @ungroup - @filter(variable == "distance_from_energy") -end -``` - - -```{julia} -df = @chain bmk() begin - @filter(variable in [ - "distance_from_energy", - "distance_from_targets", - "distance",]) - @mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable)) - @mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable)) - @mutate(variable = ifelse.(variable .== "distance", "Cost", variable)) -end -plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) * - mapping(:generator, :value, row=:variable, col=:model, color=:generator) -plt = draw( - plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120), - facet=(; linkyaxes=:minimal) -) -display(plt) -save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5) -``` \ No newline at end of file