```{julia}
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
```

# MNIST

## Anecdotal Evidence

### Examples in Introduction

#### Wachter and JSMA

```{julia}
Random.seed!(2023)

# Data:
counterfactual_data = load_mnist()
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
input_dim, n_obs = size(counterfactual_data.X)
M = load_mnist_mlp()

# Target: pick a random image that the classifier `M` predicts as `factual_label`.
factual_label = 9
x_factual = reshape(
    X[:, rand(findall(predict_label(M, counterfactual_data) .== factual_label))],
    input_dim, 1,
)
target = 7
factual = predict_label(M, counterfactual_data, x_factual)[1]
γ = 0.9     # decision threshold passed to `generate_counterfactual`

# Training params:
T = 100     # maximum number of search iterations (`max_iter`)
```

```{julia}
# Search:
generic_generator = WachterGenerator()
ce_wachter = generate_counterfactual(
    x_factual, target, counterfactual_data, M, generic_generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
)
greedy_generator = GreedyGenerator(η=2.0)
ce_jsma = generate_counterfactual(
    x_factual, target, counterfactual_data, M, greedy_generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
)
```

```{julia}
"""
    ce_panels(ce_list, names; img_height=img_height)

Return one MNIST image plot per counterfactual explanation in `ce_list`, titled
`"<name> (p=<target probability rounded to 2 digits>)"` using the matching entry of
`names`. The first element of `target_probs(ce)` is shown as the probability.
"""
function ce_panels(ce_list, names; img_height=img_height)
    return map(zip(ce_list, names)) do (ce, name)
        x_cf = CounterfactualExplanations.counterfactual(ce)
        p_target = target_probs(ce)[1]
        Plots.plot(
            convert2image(MNIST, reshape(x_cf, 28, 28)),
            axis=([], false),
            size=(img_height, img_height),
            title="$(name) (p=$(round(p_target; digits=2)))",
        )
    end
end

# Factual image; reused as the first panel of the figures below.
p1 = Plots.plot(
    convert2image(MNIST, reshape(x_factual, 28, 28)),
    axis=([], false),
    size=(img_height, img_height),
    title="Factual",
)
plts = [p1, ce_panels([ce_wachter, ce_jsma], ["Wachter", "JSMA"])...]
plt = Plots.plot(plts...; size=(img_height * length(plts), img_height), layout=(1, length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "you_may_not_like_it.png"))
```

#### REVISE

```{julia}
using CounterfactualExplanations.Models: load_mnist_vae
vae = load_mnist_vae()
vae_weak = load_mnist_vae(; strong=false)
Serialization.serialize(joinpath(output_path, "mnist_classifier.jls"), M)
Serialization.serialize(joinpath(output_path, "mnist_vae.jls"), vae)
Serialization.serialize(joinpath(output_path, "mnist_vae_weak.jls"), vae_weak)
```

```{julia}
# Define generator:
revise_generator = REVISEGenerator(
    opt = Flux.Optimise.Descent(0.25),
    λ=0.0,
)

# Generate recourse:
counterfactual_data.generative_model = vae # assign generative model
ce_strong = generate_counterfactual(
    x_factual, target, counterfactual_data, M, revise_generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
    converge_when=:generator_conditions,
)

# Same search, but with the weak VAE as the surrogate generative model:
counterfactual_data_weak = deepcopy(counterfactual_data)
counterfactual_data_weak.generative_model = vae_weak
ce_weak = generate_counterfactual(
    x_factual, target, counterfactual_data_weak, M, revise_generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
    converge_when=:generator_conditions,
)
```

```{julia}
plts = [p1, ce_panels([ce_strong, ce_weak], ["Strong VAE", "Weak VAE"])...]
plt = Plots.plot(plts...; size=(img_height * length(plts), img_height), layout=(1, length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "surrogate_gone_wrong.png"))
```

```{julia}
plts = [p1, ce_panels([ce_wachter, ce_jsma, ce_strong], ["Wachter", "Schut", "REVISE"])...]
plt = Plots.plot(
    plts...;
    size=(0.8 * panel_height * length(plts), 0.8 * panel_height),
    layout=(1, length(plts)),
    dpi=400,
)
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_motivation.png"))
```

### ECCCo

```{julia}
"""
    pre_process(x; noise=0.03f0)

Add Gaussian noise with standard deviation `noise` to `x` (scalar or array) and return
the result as `Float32`. With `noise=0.0f0` the input is returned unchanged in value.
"""
function pre_process(x; noise::Float32=0.03f0)
    ϵ = Float32.(randn(size(x)) * noise)
    x += ϵ
    return x
end
```

```{julia}
# Hyper:
_retrain = false    # retrain the models instead of loading them from disk
_regen = false      # regenerate the SGLD sample plots

# Data:
n_obs = 10000
counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = pre_process.(counterfactual_data.X)  # element-wise noise
counterfactual_data.generative_model = vae
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1)
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
n_digits = Int(sqrt(input_dim))     # side length of the square image
output_dim = length(unique(labels))
```

First, let's create a couple of image classifier architectures:

```{julia}
# Model parameters:
epochs = 25
batch_size = min(Int(round(n_obs / 10)), 128)
n_hidden = 128
activation = Flux.swish
builder = MLJFlux.@builder Flux.Chain(
    Dense(n_in, n_hidden, activation),
    Dense(n_hidden, n_out),
)
n_ens = 5                           # number of models in ensemble
_loss = Flux.Losses.crossentropy    # loss function
_finaliser = Flux.softmax           # finaliser function
```

```{julia}
# JEM parameters:
𝒟x = Uniform(-1.0, 1.0)
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(
    𝒟x, 𝒟y,
    input_size=(input_dim,),
    batch_size=10,
)
α = [1.0, 1.0, 1e-2]    # penalty strengths
```

```{julia}
# Simple MLP:
mlp = NeuralNetworkClassifier(
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
)

# Deep Ensemble:
mlp_ens = EnsembleModel(model=mlp, n=n_ens)

# Joint Energy Model:
jem = JointEnergyClassifier(
    sampler;
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
    jem_training_params=(
        α=α, verbosity=10,
    ),
    sampling_steps=25,
)

# JEM with adversarial training:
jem_adv = deepcopy(jem)
# jem_adv.adv_training = true

# Deep Ensemble of Joint Energy Models:
jem_ens = EnsembleModel(model=jem, n=n_ens)

# Deep Ensemble of Joint Energy Models with adversarial training:
# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens)

# Dictionary of models:
models = Dict(
    "MLP" => mlp,
    "MLP Ensemble" => mlp_ens,
    "JEM" => jem,
    "JEM Ensemble" => jem_ens,
    # "JEM Ensemble+" => jem_ens_plus,
)
Serialization.serialize(joinpath(output_path, "mnist_architectures.jls"), models)
```

```{julia}
# Train models:
"""
    _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model")

Wrap `model` in a conformal model (`method`, `coverage=cov`), fit it on `(X, y)` and
return it as an `ECCCo.ConformalModel`. `mod_name` is only used in log messages.
"""
function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model")
    conf_model = conformal_model(model; method=method, coverage=cov)
    mach = machine(conf_model, X, y)
    @info "Begin training $mod_name."
    fit!(mach)
    @info "Finished training $mod_name."
    M = ECCCo.ConformalModel(mach.model, mach.fitresult)
    return M
end

if _retrain
    model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models)
    Serialization.serialize(joinpath(output_path, "mnist_models.jls"), model_dict)
else
    model_dict = Serialization.deserialize(joinpath(output_path, "mnist_models.jls"))
end
```

```{julia}
params = DataFrame(
    Dict(
        :n_obs => Int.(round(n_obs / 10) * 10),
        :epochs => epochs,
        :batch_size => batch_size,
        :n_hidden => n_hidden,
        :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1,
        :activation => string(activation),
        :n_ens => n_ens,
        :lambda => string(α[3]),
        :jem_sampling_steps => jem.sampling_steps,
        :sgld_batch_size => sampler.batch_size,
        :dataname => "MNIST",
    )
)
CSV.write(joinpath(params_path, "mnist.csv"), params)
```

```{julia}
# Plot generated samples:
# Loop-local names (`𝒟x_gen`, `sampler_gen`, …) are used so that this cell does not
# overwrite the global JEM `𝒟x`, `𝒟y` and `sampler` defined above (notebook soft scope).
n_regen = 500
if _regen
    for (mod_name, mod) in model_dict
        K = length(counterfactual_data.y_levels)
        input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
        𝒟x_gen = Uniform(extrema(counterfactual_data.X)...)
        𝒟y_gen = Categorical(ones(K) ./ K)
        sampler_gen = ConditionalSampler(𝒟x_gen, 𝒟y_gen; input_size=input_size, prob_buffer=0.0)
        sgld = ImproperSGLD()
        mod_logits(x) = logits(mod, x)

        _w = 1500
        plts = []
        neach = 10
        for i in 1:10
            x = sampler_gen(mod_logits, sgld; niter=n_regen, n_samples=neach, y=i)
            plts_i = []
            for j in 1:size(x, 2)
                xj = reshape(x[:, j], (n_digits, n_digits))
                push!(plts_i, Plots.heatmap(rotl90(xj), axis=nothing, cb=false))
            end
            push!(plts, Plots.plot(plts_i..., size=(_w, 0.10 * _w), layout=(1, 10)))
        end
        plt = Plots.plot(plts..., size=(_w, _w), layout=(10, 1), plot_title=mod_name)
        savefig(plt, joinpath(output_images_path, "mnist_generated_$(mod_name).png"))
        display(plt)
    end
end
```

```{julia}
# Evaluate models:
measure = Dict(
    :f1score => multiclass_f1score,
    :acc => accuracy,
    :precision => multiclass_precision
)
model_performance = DataFrame()
test_data = load_mnist_test()   # loaded once, shared by all models
for (mod_name, mod) in model_dict
    # Test performance:
    _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
    _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
    _perf.mod_name .= mod_name
    _perf.dataname .= "MNIST"
    model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path, "mnist_model_performance.jls"), model_performance)
CSV.write(joinpath(output_path, "mnist_model_performance.csv"), model_performance)
model_performance
```

### Different Models

```{julia}
"""
    plot_mnist(ce; size=(img_height, img_height), kwrgs...)

Plot the counterfactual image of explanation `ce`; extra keywords go to `Plots.plot`.
"""
function plot_mnist(ce; size=(img_height, img_height), kwrgs...)
    x = CounterfactualExplanations.counterfactual(ce)
    plt = Plots.plot(
        convert2image(MNIST, reshape(x, 28, 28));
        axis=([], false),
        size=size,
        kwrgs...,
    )
    return plt
end

"""
    _plot_eccco_mnist(x=x_factual, target=target; kwargs...)

Generate one counterfactual for `x` per model in `model_dict` and plot them side by side.

- `x` is either an input array (used as-is) or an `Int` digit label. For a label, a random
  image from `counterfactual_data` carrying that label is drawn after seeding with `rng`.
- Unless `generator` is supplied, an `ECCCoGenerator` is built from `λ`, `temp`, `opt`,
  `use_class_loss` and `use_energy_delta`.
- Search settings: `decision_threshold=γ`, `max_iter=T`, identity initialisation,
  convergence on generator conditions. `test_data=true` searches against `load_mnist_test()`.
- Panels follow `plt_order` and are titled "(a)", "(b)", …. `plot_factual=true` prepends
  the factual image. `wide=true` lays the panels out in one row. Remaining `kwrgs` go to the
  final `Plots.plot` call.

Returns `(plt, generator, ces)`, where `ces` maps model names to explanations.
"""
function _plot_eccco_mnist(
    x::Union{AbstractArray, Int}=x_factual, target::Int=target;
    λ=[0.1, 0.25, 0.25],
    temp=0.1,
    plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"],
    opt = nothing,
    rng::Union{Int,AbstractRNG}=1234,
    T::Int = 100,
    use_class_loss::Bool = true,
    model_dict=model_dict,
    wide::Bool = false,
    img_height::Int = img_height,
    plot_factual::Bool = false,
    generator::Union{Nothing,CounterfactualExplanations.AbstractGenerator}=nothing,
    test_data::Bool = false,
    use_energy_delta::Bool = false,
    kwrgs...,
)
    # Setup:
    Random.seed!(rng)
    if x isa Int
        x_fact = counterfactual_data.X[:, rand(findall(labels .== x))][:, :]
    else
        x_fact = x
    end

    if isnothing(generator)
        # Generate counterfactuals using ECCCo generator:
        generator = ECCCoGenerator(
            λ=λ,
            temp=temp,
            opt=opt,
            use_class_loss=use_class_loss,
            nsamples=10,
            nmin=10,
            use_energy_delta=use_energy_delta,
        )
    end

    data = test_data ? load_mnist_test() : counterfactual_data

    ces = Dict()
    for (mod_name, mod) in model_dict
        ces[mod_name] = generate_counterfactual(
            x_fact, target, data, mod, generator;
            decision_threshold=γ, max_iter=T,
            initialization=:identity,
            converge_when=:generator_conditions,
        )
    end

    # Plot:
    p1 = Plots.plot(
        convert2image(MNIST, reshape(x_fact, 28, 28)),
        axis=([], false),
        size=(img_height, img_height),
        title="Factual",
    )
    plts = Any[]
    plot_factual && push!(plts, p1)

    # Look each panel up by name so the order follows `plt_order` exactly, instead of
    # relying on `ces` and `model_dict` (two distinct Dicts) iterating in the same order.
    for (i, mod_name) in enumerate(plt_order)
        _x = CounterfactualExplanations.counterfactual(ces[mod_name])
        _title = "($('a' + i - 1))"
        plt = Plots.plot(
            convert2image(MNIST, reshape(_x, 28, 28)),
            axis=([], false),
            size=(img_height, img_height),
            title=_title,
        )
        push!(plts, plt)
    end

    if wide
        plt = Plots.plot(plts...; size=(img_height * length(plts), img_height), layout=(1, length(plts)), kwrgs...)
    else
        plt = Plots.plot(plts...; size=(img_height, img_height), kwrgs...)
    end

    return plt, generator, ces
end
```

```{julia}
plt, eccco_generator, ces = _plot_eccco_mnist()
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
```

#### Energy Delta (not in paper)

```{julia}
plt, gen_delta, ces = _plot_eccco_mnist(λ = [0.1, 0.1, 3.0], use_energy_delta=true)
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco_energy_delta.png"))
```

```{julia}
λ_delta = [0.1, 0.1, 2.5]
λ = [0.1, 0.25, 0.25]
plts = []
for i in 0:9
    plt, _, _ = _plot_eccco_mnist(x_factual, i; λ = λ, plot_title="Distance")
    plt_delta, _, _ = _plot_eccco_mnist(x_factual, i; λ = λ_delta, use_energy_delta=true, plot_title="Energy Delta")
    plt = Plots.plot(plt, plt_delta; size=(img_height * 2, img_height), layout=(1, 2))
    display(plt)
    push!(plts, plt)
end
```

#### Additional Models (not in paper)

LeNet-5:

```{julia}
mutable struct LeNetBuilder
    filter_size::Int
    channels1::Int
    channels2::Int
end

preproc(X) = reshape(X, (28, 28, 1, :))

# Build a LeNet-5-style CNN for square single-channel images with `n_in` pixels.
function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out)

    _n_in = Int(sqrt(n_in))

    k, c1, c2 = b.filter_size, b.channels1, b.channels2

    mod(k, 2) == 1 || error("`filter_size` must be odd. ")

    # padding to preserve image size on convolution:
    p = div(k - 1, 2)

    # Flat input columns -> (height, width, channel, batch):
    preproc(x) = reshape(x, (_n_in, _n_in, 1, :))

    front = Flux.Chain(
        Conv((k, k), 1 => c1, pad=(p, p), relu),
        MaxPool((2, 2)),
        Conv((k, k), c1 => c2, pad=(p, p), relu),
        MaxPool((2, 2)),
        Flux.flatten
    )
    d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first
    back = Flux.Chain(
        Dense(d, 120, relu),
        Dense(120, 84, relu),
        Dense(84, n_out),
    )
    chain = Flux.Chain(preproc, front, back)

    return chain
end

# Final model:
lenet = NeuralNetworkClassifier(
    builder=LeNetBuilder(5, 6, 16),
    epochs=50,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
)
```

Robust neural network:

```{julia}
mutable struct RobustNetBuilder
    n_hidden::Int
    lipschitz_bound::Float32
end

# Build a Lipschitz-bounded network (two hidden layers of width `n_hidden`).
function MLJFlux.build(b::RobustNetBuilder, rng, n_in, n_out)
    n_hidden, γ = b.n_hidden, b.lipschitz_bound
    _n_hidden = fill(n_hidden, 2)
    model_ps = DenseLBDNParams{Float32}(n_in, _n_hidden, n_out, γ; rng)
    chain = Flux.Chain(DiffLBDN(model_ps))
    return chain
end

# Final model:
rob_net = NeuralNetworkClassifier(
    builder=RobustNetBuilder(60, 5.0),
    epochs=600,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
)
```

Training all of them:

```{julia}
add_retrain = true

# Deep Ensemble:
mlp_large_ens = EnsembleModel(model=mlp, n=50)

# CNN Ensemble:
lenet_ens = EnsembleModel(model=lenet, n=5)

add_models = Dict(
    "LeNet-5" => lenet,
    # "LeNet-5 Ensemble" => lenet_ens,
    # "RobustNet" => rob_net,
    "Large Ensemble (n=50)" => mlp_large_ens,
)

if add_retrain
    add_model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in add_models)
    large_model_dict = merge(model_dict, add_model_dict)
    Serialization.serialize(joinpath(output_path, "mnist_models_large.jls"), large_model_dict)
else
    large_model_dict = Serialization.deserialize(joinpath(output_path, "mnist_models_large.jls"))
end
```

```{julia}
# Evaluate models:
measure = Dict(
    :f1score => multiclass_f1score,
    :acc => accuracy,
    :precision => multiclass_precision
)
model_performance = DataFrame()
test_data = load_mnist_test()   # loaded once, shared by all models
for (mod_name, mod) in large_model_dict
    # Test performance:
    _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
    _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
    _perf.mod_name .= mod_name
    _perf.dataname .= "MNIST"
    model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path, "mnist_model_performance_additional.jls"), model_performance)
CSV.write(joinpath(output_path, "mnist_model_performance_additional.csv"), model_performance)
model_performance
```

```{julia}
_plt_order = [
    "MLP",
    "MLP Ensemble",
    "Large Ensemble (n=50)",
    "LeNet-5",
    # "LeNet-5 Ensemble",
    # "RobustNet",
    "JEM",
    "JEM Ensemble",
]

plt_additional_models, _, _ces_ = _plot_eccco_mnist(
    λ = [0.1, 0.21, 0.21],
    plt_order = _plt_order,
    model_dict=large_model_dict,
    wide = true,
    plot_factual = true,
    img_height = 150,
)
display(plt_additional_models)
# savefig(plt_additional_models, joinpath(output_images_path, "mnist_eccco_additional.png"))
```

```{julia}
Random.seed!(123)
λ = [0.1, 0.22, 0.22]
wachter = WachterGenerator(
    λ=λ[1],
    opt=eccco_generator.opt
)

# (factual, target) digit pairs; each pair is run `n_each` times with different seeds.
combos = [
    (9, 7),
    (7, 2),
    (6, 0),
    (4, 1),
    (2, 3),
    (5, 8),
    (6, 5),
    (1, 7),
    (1, 4),
    (0, 3)
]
n_each = 3
combos = repeat(combos; inner=n_each)

plts_eccco = []
plts_wachter = []
ces_eccco = []
ces_wachter = []
rngs = []
for (factual, target) in combos
    rng = rand(1:10000)     # same seed for the ECCCo and Wachter runs of this pair
    # ECCCo:
    plt, _, ces = _plot_eccco_mnist(
        factual, target;
        λ = λ,
        plt_order = _plt_order,
        model_dict = large_model_dict,
        wide = true,
        plot_factual = true,
        rng = rng,
        img_height = 150
    )
    display(plt)
    push!(plts_eccco, plt)
    push!(ces_eccco, reduce(hcat, target_probs.(values(ces))))
    # Wachter:
    plt, _, ces = _plot_eccco_mnist(
        factual, target;
        plt_order = _plt_order,
        model_dict = large_model_dict,
        wide = true,
        plot_factual = true,
        rng = rng,
        img_height = 150,
        generator = wachter,
    )
    display(plt)
    push!(plts_wachter, plt)
    push!(ces_wachter, reduce(hcat, target_probs.(values(ces))))
    push!(rngs, rng)
end
```

```{julia}
# Stack the ECCCo (top) and Wachter (bottom) rows for each run. Loop-local names avoid
# overwriting the global factual plot `p1` (notebook soft scope).
final_plts = []
for (i, (factual, target)) in enumerate(combos)
    p_eccco = plts_eccco[i]
    p_wachter = plts_wachter[i]
    plt = Plots.plot(
        p_eccco, p_wachter,
        layout=(2, 1),
        size = (1200, 400),
    )
    display(plt)
    savefig(plt, joinpath("dev/rebuttal/www", "mnist_$(factual)to$(target)_$(i).png"))
    push!(final_plts, plt)
end
```

### All digits

```{julia}
"""
    plot_mnist(factual, target; generator, model, data, rng, _plot_title, show_factual, img_height, kwargs...)

Draw a random image from `data` that `model` predicts as `factual`, search for a
counterfactual towards `target` with `generator`, and plot it (next to the factual when
`show_factual=true`).

Search settings default to `decision_threshold=0.9`, `max_iter=100`,
`initialization=:identity` and `converge_when=:generator_conditions`. Any of them, and any
other keyword accepted by `generate_counterfactual`, can be overridden through `kwargs`.
"""
function plot_mnist(
    factual::Int, target::Int;
    generator::AbstractGenerator,
    model::AbstractFittedModel=model_dict["JEM Ensemble"],
    data::CounterfactualData=counterfactual_data,
    rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG,
    _plot_title::Bool=true,
    show_factual::Bool=false,
    img_height::Int=180,
    kwargs...,
)
    Random.seed!(rng)

    # `kwargs` is a `Base.Pairs`; `isdefined(kwargs, :key)` inspects struct fields and is
    # always false, so look the keys up with `get` instead.
    decision_threshold = get(kwargs, :decision_threshold, 0.9)
    max_iter = get(kwargs, :max_iter, 100)
    initialization = get(kwargs, :initialization, :identity)
    converge_when = get(kwargs, :converge_when, :generator_conditions)

    x = reshape(data.X[:, rand(findall(predict_label(model, data) .== factual))], input_dim, 1)
    ce = generate_counterfactual(
        x, target, data, model, generator;
        decision_threshold=decision_threshold, max_iter=max_iter,
        initialization=initialization,
        converge_when=converge_when,
        kwargs...
    )

    _title = _plot_title ? "$(factual) -> $(target)" : ""
    _x = CounterfactualExplanations.counterfactual(ce)
    plt = Plots.plot(
        convert2image(MNIST, reshape(_x, 28, 28)),
        axis=([], false),
        size=(img_height, img_height),
        title=_title
    )
    if show_factual
        plt_factual = Plots.plot(
            convert2image(MNIST, reshape(x, 28, 28)),
            axis=([], false),
            size=(img_height, img_height),
            title="Factual"
        )
        plt = Plots.plot(plt_factual, plt; size=(img_height * 2, img_height), layout=(1, 2))
    end
    return plt
end
```

```{julia}
"""
    plot_all_digits(rng=123; verbose=true, img_height=180, kwargs...)

Plot a 10×10 grid with one counterfactual per (factual, target) digit pair, including
i == j. Rows are factual digits and columns are target digits. `kwargs` are forwarded to
`plot_mnist` (e.g. `generator=...`). With `verbose=true` every panel is displayed as it
is generated.
"""
function plot_all_digits(rng=123; verbose=true, img_height=180, kwargs...)
    plts = []
    for i in 0:9, j in 0:9
        @info "Generating counterfactual for $(i) -> $(j)"
        plt = plot_mnist(i, j; kwargs..., rng=rng, img_height=img_height)
        !verbose || display(plt)
        push!(plts, plt)
    end
    plt = Plots.plot(plts...; size=(img_height * 10, img_height * 10), layout=(10, 10), dpi=300)
    return plt
end

_regen_all_digits = false
if _regen_all_digits
    plt = plot_all_digits(generator=eccco_generator)
    savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits.png"))
end
```

#### Energy Delta (not in paper)

```{julia}
_regen_all_digits = true
if _regen_all_digits
    plt = plot_all_digits(generator=gen_delta)
    savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits-delta.png"))
end
```

## Benchmark

```{julia}
Λ = eccco_generator.λ

# Benchmark generators:
generator_dict = Dict(
    "Wachter" => WachterGenerator(λ=Λ[1], opt=eccco_generator.opt),
    "REVISE" => REVISEGenerator(λ=Λ[1], opt=eccco_generator.opt),
    "Schut" => greedy_generator,
    "ECCCo" => eccco_generator,
)
```

```{julia}
generator_params = DataFrame(
    Dict(
        :λ1 => Λ[1],
        :λ2 => Λ[2],
        :λ3 => Λ[3],
        :opt => string(typeof(eccco_generator.opt)),
        :eta => eccco_generator.opt.eta,
        :dataname => "MNIST",
    )
)
CSV.write(joinpath(params_path, "generator/mnist.csv"), generator_params)
```

```{julia}
# Measures:
measures = [
    CounterfactualExplanations.distance,
    ECCCo.distance_from_energy,
    ECCCo.distance_from_targets,
    CounterfactualExplanations.Evaluation.validity,
    CounterfactualExplanations.Evaluation.redundancy,
    ECCCo.set_size_penalty,
]

# Benchmark every ordered pair of distinct digit labels:
label_levels = sort(unique(labels))
bmks = []
for target in label_levels, factual in label_levels
    factual == target && continue
    bmk = benchmark(
        counterfactual_data;
        models=model_dict,
        generators=generator_dict,
        measure=measures,
        suppress_training=true, dataname="MNIST",
        n_individuals=5,
        target=target, factual=factual,
        initialization=:identity,
        converge_when=:generator_conditions,
    )
    push!(bmks, bmk)
end
bmk = reduce(vcat, bmks)
CSV.write(joinpath(output_path, "mnist_benchmark.csv"), bmk())
```

```{julia}
# Relabel measure names for the figure. The measure functions are lowercase
# (`validity`, `redundancy`, ...), so match the lowercase names.
df = @chain bmk() begin
    @mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
    @mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable))
    @mutate(variable = ifelse.(variable .== "distance", "Cost", variable))
    @mutate(variable = ifelse.(variable .== "redundancy", "Redundancy", variable))
    @mutate(variable = ifelse.(variable .== "validity", "Validity", variable))
end
plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) *
    mapping(:generator, :value, row=:variable, col=:model, color=:generator)
plt = draw(
    plt,
    axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120),
    facet=(; linkyaxes=:minimal)
)
display(plt)
save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5)
```