```{julia}
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
```

# FashionMNIST

## Anecdotal Evidence

### Examples in Introduction

#### Wachter and JSMA

```{julia}
Random.seed!(2023)

# Data:
counterfactual_data = load_fashion_mnist()
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
input_dim, n_obs = size(counterfactual_data.X)
M = load_fashion_mnist_mlp()

# Factual and target (FashionMNIST classes: 9 = ankle boot, 7 = sneaker):
factual_label = 9
x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 7
factual = predict_label(M, counterfactual_data, x_factual)[1]
γ = 0.9

# Training params:
T = 100
```

```{julia}
# Search:
generic_generator = WachterGenerator()
ce_wachter = generate_counterfactual(
    x_factual, target, counterfactual_data, M, generic_generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
)
greedy_generator = GreedyGenerator(η=2.0)
ce_jsma = generate_counterfactual(
    x_factual, target, counterfactual_data, M, greedy_generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
)
```

```{julia}
p1 = Plots.plot(
    convert2image(FashionMNIST, reshape(x_factual,28,28)),
    axis=nothing,
    size=(img_height, img_height),
    title="Factual"
)
plts = [p1]

ces = zip([ce_wachter,ce_jsma])
counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces))
phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces))
for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Wachter","JSMA"])
    ce, _phat, _name = (x[1],x[2],x[3])
    _title = "$(_name) (p=$(round(_phat[1]; digits=2)))"
    plt = Plots.plot(
        convert2image(FashionMNIST, reshape(ce,28,28)),
        axis=nothing,
        size=(img_height, img_height),
        title=_title
    )
    plts = [plts..., plt]
end
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "fashion_you_may_not_like_it.png"))
```

#### REVISE

```{julia}
using CounterfactualExplanations.Models: load_fashion_mnist_vae
vae = load_fashion_mnist_vae()
vae_weak = load_fashion_mnist_vae(;strong=false)
Serialization.serialize(joinpath(output_path,"fashion_mnist_classifier.jls"), M)
Serialization.serialize(joinpath(output_path,"fashion_mnist_vae.jls"), vae)
Serialization.serialize(joinpath(output_path,"fashion_mnist_vae_weak.jls"), vae_weak)
```

```{julia}
# Define generator:
revise_generator = REVISEGenerator(
    opt = Flux.Optimise.Descent(0.25),
    λ=0.0,
)
# Generate recourse:
counterfactual_data.generative_model = vae # assign generative model
ce_strong = generate_counterfactual(
    x_factual, target, counterfactual_data, M, revise_generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
    converge_when=:generator_conditions,
)
counterfactual_data_weak = deepcopy(counterfactual_data)
counterfactual_data_weak.generative_model = vae_weak
ce_weak = generate_counterfactual(
    x_factual, target, counterfactual_data_weak, M, revise_generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
    converge_when=:generator_conditions,
)
```
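Before plotting, a quick sanity check can be run to compare the target-class probability each counterfactual attains under the strong and the weak VAE. This is an illustrative sketch (not part of the original pipeline) and only uses the same `target_probs` helper as the plotting cell below.

```{julia}
# Illustrative check: target-class probability under the strong vs. weak VAE.
for (name, ce) in [("Strong VAE", ce_strong), ("Weak VAE", ce_weak)]
    println("$(name): p(target) = $(round(target_probs(ce)[1]; digits=3))")
end
```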
```{julia}
ces = zip([ce_strong,ce_weak])
counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces))
phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces))
plts = [p1]
for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Strong VAE","Weak VAE"])
    ce, _phat, _name = (x[1],x[2],x[3])
    _title = "$(_name) (p=$(round(_phat[1]; digits=2)))"
    plt = Plots.plot(
        convert2image(FashionMNIST, reshape(ce,28,28)),
        axis=nothing,
        size=(img_height, img_height),
        title=_title
    )
    plts = [plts..., plt]
end
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "fashion_surrogate_gone_wrong.png"))
```

### ECCCo

```{julia}
# Add a small amount of Gaussian noise to the inputs:
function pre_process(x; noise::Float32=0.03f0)
    ϵ = Float32.(randn(size(x)) * noise)
    x += ϵ
    return x
end
```

```{julia}
# Hyper:
_retrain = true
_regen = true

# Data:
n_obs = 10000
counterfactual_data = load_fashion_mnist(n_obs)
counterfactual_data.X = pre_process(counterfactual_data.X)  # apply to the whole feature matrix
counterfactual_data.generative_model = vae
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1)
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
n_digits = Int(sqrt(input_dim))
output_dim = length(unique(labels))
```

First, let's create a couple of image classifier architectures:

```{julia}
# Model parameters:
epochs = 100
batch_size = minimum([Int(round(n_obs/10)), 128])
n_hidden = 128
activation = Flux.swish
builder = MLJFlux.@builder Flux.Chain(
    Dense(n_in, n_hidden, activation),
    Dense(n_hidden, n_out),
)
n_ens = 5 # number of models in ensemble
_loss = Flux.Losses.logitcrossentropy # loss function
_finaliser = x -> x # finaliser function
```

```{julia}
# JEM parameters:
𝒟x = Uniform(-1.0,1.0)
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(
    𝒟x, 𝒟y,
    input_size=(input_dim,),
    batch_size=10,
)
α = [1.0,1.0,2e-2] # penalty strengths
```

```{julia}
# Simple MLP:
mlp = NeuralNetworkClassifier(
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
)

# Deep Ensemble:
mlp_ens = EnsembleModel(model=mlp, n=n_ens)

# Joint Energy Model:
jem = JointEnergyClassifier(
    sampler;
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
    jem_training_params=(
        α=α, verbosity=10,
    ),
    sampling_steps=25,
)

# JEM with adversarial training:
jem_adv = deepcopy(jem)
# jem_adv.adv_training = true

# Deep Ensemble of Joint Energy Models:
jem_ens = EnsembleModel(model=jem, n=n_ens)

# Deep Ensemble of Joint Energy Models with adversarial training:
# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens)

# Dictionary of models:
models = Dict(
    "MLP" => mlp,
    "MLP Ensemble" => mlp_ens,
    "JEM" => jem,
    "JEM Ensemble" => jem_ens,
    # "JEM Ensemble+" => jem_ens_plus,
)
Serialization.serialize(joinpath(output_path,"fashion_mnist_architectures.jls"), models)
```
```{julia}
# Train models:
function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model")
    conf_model = conformal_model(model; method=method, coverage=cov)
    mach = machine(conf_model, X, y)
    @info "Begin training $mod_name."
    fit!(mach)
    @info "Finished training $mod_name."
    M = ECCCo.ConformalModel(mach.model, mach.fitresult)
    return M
end
if _retrain
    model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models)
    Serialization.serialize(joinpath(output_path,"fashion_mnist_models.jls"), model_dict)
else
    model_dict = Serialization.deserialize(joinpath(output_path,"fashion_mnist_models.jls"))
end
```

```{julia}
# Plot generated samples:
n_regen = 500
if _regen
    for (mod_name, mod) in model_dict
        K = length(counterfactual_data.y_levels)
        input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
        𝒟x = Uniform(extrema(counterfactual_data.X)...)
        𝒟y = Categorical(ones(K) ./ K)
        sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size, prob_buffer=0.0)
        opt = ImproperSGLD()
        f(x) = logits(mod, x)
        _w = 1500
        plts = []
        neach = 10
        for i in 1:10
            x = sampler(f, opt; niter=n_regen, n_samples=neach, y=i)
            plts_i = []
            for j in 1:size(x, 2)
                xj = x[:,j]
                xj = reshape(xj, (n_digits, n_digits))
                plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
            end
            plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
            plts = [plts..., plt]
        end
        plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name)
        savefig(plt, joinpath(output_images_path, "fashion_mnist_generated_$(mod_name).png"))
        display(plt)
    end
end
```

```{julia}
# Evaluate models:
measure = Dict(
    :f1score => multiclass_f1score,
    :acc => accuracy,
    :precision => multiclass_precision
)
model_performance = DataFrame()
for (mod_name, mod) in model_dict
    # Test performance:
    test_data = load_fashion_mnist_test()
    _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
    _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
    _perf.mod_name .= mod_name
    model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path,"fashion_mnist_model_performance.jls"), model_performance)
CSV.write(joinpath(output_path, "fashion_mnist_model_performance.csv"), model_performance)
model_performance
```

### Different Models
```{julia}
function _plot_eccco_fashion_mnist(
    x::Union{AbstractArray, Int}=x_factual, target::Int=target;
    λ=[0.1,0.25,0.25],
    temp=0.1,
    plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"],
    opt = nothing,
    rng::Union{Int,AbstractRNG}=1234,
    T::Int = 100,
)

    # Setup:
    Random.seed!(rng)
    if x isa Int
        x = reshape(counterfactual_data.X[:,rand(findall(labels.==x))],input_dim,1)
    end

    # Generate counterfactuals using ECCCo generator:
    eccco_generator = ECCCoGenerator(
        λ=λ,
        temp=temp,
        opt=opt,
    )

    ces = Dict()
    for (mod_name, mod) in model_dict
        ce = generate_counterfactual(
            x, target, counterfactual_data, mod, eccco_generator;
            decision_threshold=γ, max_iter=T,
            initialization=:identity,
            converge_when=:generator_conditions,
        )
        ces[mod_name] = ce
    end
    _plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order)

    # Plot:
    p1 = Plots.plot(
        convert2image(FashionMNIST, reshape(x,28,28)),
        axis=nothing,
        size=(img_height, img_height),
        title="Factual"
    )
    plts = []
    for (_name,ce) in ces
        _x = CounterfactualExplanations.counterfactual(ce)
        _phat = target_probs(ce)
        _title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
        plt = Plots.plot(
            convert2image(FashionMNIST, reshape(_x,28,28)),
            axis=nothing,
            size=(img_height, img_height),
            title=_title
        )
        plts = [plts..., plt]
    end
    plts = plts[_plt_order]
    plts = [p1, plts...]
    plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))

    return plt, eccco_generator
end
```

```{julia}
plt, eccco_generator = _plot_eccco_fashion_mnist()
display(plt)
savefig(plt, joinpath(output_images_path, "fashion_mnist_eccco.png"))
```

### All digits

The "digit" terminology carries over from MNIST; for FashionMNIST the ten classes are clothing items, and we loop over all factual–target class combinations in the same way.

```{julia}
function plot_fashion_mnist(
    factual::Int,target::Int;
    generator::AbstractGenerator,
    model::AbstractFittedModel=model_dict["JEM Ensemble"],
    data::CounterfactualData=counterfactual_data,
    rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG,
    _plot_title::Bool=true,
    show_factual::Bool=false,
    kwargs...,
)

    Random.seed!(rng)

    # Keyword defaults; any of these can be overridden through `kwargs...`:
    kwargs = Dict{Symbol,Any}(kwargs)
    decision_threshold = pop!(kwargs, :decision_threshold, 0.9)
    max_iter = pop!(kwargs, :max_iter, 100)
    initialization = pop!(kwargs, :initialization, :identity)
    converge_when = pop!(kwargs, :converge_when, :generator_conditions)

    x = reshape(data.X[:,rand(findall(predict_label(model, data).==factual))],input_dim,1)
    ce = generate_counterfactual(
        x, target, data, model, generator;
        decision_threshold=decision_threshold, max_iter=max_iter,
        initialization=initialization,
        converge_when=converge_when,
        kwargs...
    )

    _title = _plot_title ? "$(factual) -> $(target)" : ""

    _x = CounterfactualExplanations.counterfactual(ce)
    plt = Plots.plot(
        convert2image(FashionMNIST, reshape(_x,28,28)),
        axis=nothing,
        size=(img_height, img_height),
        title=_title
    )

    if show_factual
        plt_factual = Plots.plot(
            convert2image(FashionMNIST, reshape(x,28,28)),
            axis=nothing,
            size=(img_height, img_height),
            title="Factual"
        )
        plt = Plots.plot(plt_factual, plt; size=(img_height*2,img_height), layout=(1,2))
    end

    return plt
end
```
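For illustration, the helper can also be called for a single pair of classes. The sketch below is not part of the original experiments; it uses FashionMNIST classes 5 (sandal) and 7 (sneaker), shows the factual image alongside the counterfactual, and relies on the `eccco_generator` returned above.

```{julia}
# Illustrative single-pair usage of `plot_fashion_mnist` (5 = sandal, 7 = sneaker):
plt_single = plot_fashion_mnist(5, 7; generator=eccco_generator, show_factual=true, rng=123)
display(plt_single)
```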
```{julia}
if _regen
    function plot_all_digits(rng=123;verbose=true,kwargs...)
        plts = []
        for i in 0:9
            for j in 0:9
                @info "Generating counterfactual for $(i) -> $(j)"
                plt = plot_fashion_mnist(i,j;kwargs...,rng=rng)
                verbose && display(plt)
                plts = [plts..., plt]
            end
        end
        plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10))
        return plt
    end
    plt = plot_all_digits(generator=eccco_generator)
    savefig(plt, joinpath(output_images_path, "fashion_mnist_eccco_all_digits.png"))
end
```

## Benchmark

```{julia}
λ₁ = 0.1
λ₂ = 0.25
λ₃ = 0.25
Λ = [λ₁, λ₂, λ₃]

# Benchmark generators:
generator_dict = Dict(
    "Wachter" => WachterGenerator(λ=λ₁),
    "REVISE" => REVISEGenerator(λ=λ₁),
    "Schut" => GreedyGenerator(),
    "ECCCo" => ECCCoGenerator(λ=Λ),
)
```

```{julia}
# Measures:
measures = [
    CounterfactualExplanations.distance,
    ECCCo.distance_from_energy,
    ECCCo.distance_from_targets,
    CounterfactualExplanations.Evaluation.validity,
    CounterfactualExplanations.Evaluation.redundancy,
    ECCCo.set_size_penalty
]

bmks = []
for target in sort(unique(labels))
    for factual in sort(unique(labels))
        if factual == target
            continue
        end
        bmk = benchmark(
            counterfactual_data;
            models=model_dict,
            generators=generator_dict,
            measure=measures,
            suppress_training=true, dataname="FashionMNIST",
            n_individuals=5,
            target=target, factual=factual,
            initialization=:identity,
            converge_when=:generator_conditions,
        )
        push!(bmks, bmk)
    end
end
bmk = reduce(vcat, bmks)
CSV.write(joinpath(output_path, "fashion_mnist_benchmark.csv"), bmk())
```

```{julia}
df = @chain bmk() begin
    @mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
    @mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable))
    @mutate(variable = ifelse.(variable .== "distance", "Cost", variable))
    @mutate(variable = ifelse.(variable .== "redundancy", "Redundancy", variable))
    @mutate(variable = ifelse.(variable .== "validity", "Validity", variable))
end
plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) *
    mapping(:generator, :value, row=:variable, col=:model, color=:generator)
plt = draw(
    plt,
    axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120),
    facet=(; linkyaxes=:minimal)
)
display(plt)
save(joinpath(output_images_path, "fashion_mnist_benchmark.png"), plt, px_per_unit=5)
```
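To summarise the benchmark beyond the box plots, the raw evaluation data can be aggregated by model, generator and measure. The following is a minimal, illustrative sketch (not part of the original pipeline); it only assumes the `model`, `generator`, `variable` and `value` columns already used in the plotting step above.

```{julia}
# Illustrative aggregation: mean value of each measure per model and generator.
using Statistics: mean
bmk_summary = combine(
    groupby(bmk(), [:model, :generator, :variable]),
    :value => mean => :mean_value,
)
bmk_summary
```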