```{julia}
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
```

# Linearly Separable Data

```{julia}
# Hyper:
_retrain = true

# Data:
test_size = 0.2
n_obs = Int(1000 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(
    load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0));
    test_size=test_size
)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
output_dim = length(unique(labels))
```

First, let's create a couple of classifier architectures:

```{julia}
# Model parameters:
epochs = 100
batch_size = minimum([Int(round(n_obs / 10)), 128])
n_hidden = 16
activation = Flux.swish
builder = MLJFlux.MLP(
    hidden=(n_hidden, n_hidden, n_hidden),
    σ=activation
)
n_ens = 5                           # number of models in ensemble
_loss = Flux.Losses.crossentropy    # loss function
_finaliser = Flux.softmax           # finaliser function
```

```{julia}
# JEM parameters:
𝒟x = Normal()
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(
    𝒟x, 𝒟y,
    input_size=(input_dim,),
    batch_size=50,
)
α = [1.0, 1.0, 1e-1]                # penalty strengths
```

```{julia}
# Simple MLP:
mlp = NeuralNetworkClassifier(
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
)

# Deep Ensemble:
mlp_ens = EnsembleModel(model=mlp, n=n_ens)

# Joint Energy Model:
jem = JointEnergyClassifier(
    sampler;
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
    jem_training_params=(
        α=α, verbosity=10,
    ),
    sampling_steps=30,
)

# JEM with adversarial training:
jem_adv = deepcopy(jem)
# jem_adv.adv_training = true

# Deep Ensemble of Joint Energy Models:
jem_ens = EnsembleModel(model=jem, n=n_ens)

# Deep Ensemble of Joint Energy Models with adversarial training:
# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens)

# Dictionary of models:
models = Dict(
    "MLP" => mlp,
    # "MLP Ensemble" => mlp_ens,
    "JEM" => jem,
    # "JEM Ensemble" => jem_ens,
    # "JEM Ensemble+" => jem_ens_plus,
)
```

```{julia}
# Train models:
function _train(model, X=X, y=labels; cov=0.95, method=:simple_inductive, mod_name="model")
    conf_model = conformal_model(model; method=method, coverage=cov)
    mach = machine(conf_model, X, y)
    @info "Begin training $mod_name."
    fit!(mach)
    @info "Finished training $mod_name."
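    # Wrap the fitted conformal machine in `ECCCo.ConformalModel` so it exposes
    # the `logits`/`probs` interface used by the counterfactual generators below: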
    M = ECCCo.ConformalModel(mach.model, mach.fitresult)
    return M
end

if _retrain
    model_dict = Dict(mod_name => _train(model; mod_name=mod_name) for (mod_name, model) in models)
    Serialization.serialize(joinpath(output_path, "linearly_separable_models.jls"), model_dict)
else
    model_dict = Serialization.deserialize(joinpath(output_path, "linearly_separable_models.jls"))
end
```

```{julia}
# Evaluate models:
measure = Dict(
    :f1score => multiclass_f1score,
    :acc => accuracy,
    :precision => multiclass_precision
)
model_performance = DataFrame()
for (mod_name, model) in model_dict
    # Test performance:
    _perf = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=collect(values(measure)))
    _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
    _perf.mod_name .= mod_name
    model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path, "linearly_separable_model_performance.jls"), model_performance)
CSV.write(joinpath(output_path, "linearly_separable_model_performance.csv"), model_performance)
model_performance
```

```{julia}
n_regen = 1000
n_each = batch_size
for (mod_name, model) in model_dict
    K = length(counterfactual_data.y_levels)
    input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
    𝒟x = Uniform(extrema(counterfactual_data.X)...)
    𝒟y = Categorical(ones(K) ./ K)
    sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size)
    opt = ImproperSGLD()
    plts = []
    for target in levels(labels)
        target_idx = findall(levels(labels) .== target)[1]
        f(x) = logits(model, x)
        X̂ = sampler(f, opt; niter=n_regen, n_samples=n_each, y=target_idx)
        ex = extrema(hcat(MLJFlux.reformat(X), X̂), dims=2)
        xlims = ex[1]
        ylims = ex[2]
        x1 = range(1.0f0 .* xlims..., length=10)
        x2 = range(1.0f0 .* ylims..., length=10)
        p(x) = probs(model, x)
        plt = Plots.contour(
            x1, x2, (x, y) -> p([x, y][:, :])[target_idx],
            fill=true, alpha=0.5,
            title="Target: $target", cbar=true,
            xlims=xlims, ylims=ylims,
        )
        Plots.scatter!(
            MLJFlux.reformat(X)[1, :], MLJFlux.reformat(X)[2, :],
            color=Int.(labels.refs), group=Int.(labels.refs), alpha=0.5
        )
        Plots.scatter!(
            X̂[1, :], X̂[2, :],
            color=repeat([target_idx], size(X̂, 2)),
            group=repeat([target_idx], size(X̂, 2)),
            shape=:star5, ms=10
        )
        push!(plts, plt)
    end
    plt = Plots.plot(plts..., layout=(1, 2), size=(2 * 500, 400), plot_title=mod_name)
    display(plt)
    # Save the combined panel; saving inside the target loop would overwrite the
    # file on every iteration, since the file name only depends on the model:
    savefig(plt, joinpath(output_images_path, "linearly_separable_generated_$(mod_name).png"))
end
```

```{julia}
#| output: true
#| echo: false
#| label: fig-losses
#| fig-cap: "Illustration of the smooth size loss and the configurable classification loss."
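# For each trained model, plot the conformal prediction set size, the smooth set
# size loss and the (configurable) classification loss over the input domain: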
X_plot = matrix(X)
temp = 0.1
for (mod_name, model) in model_dict
    p0 = Plots.contourf(model.model, model.fitresult, X_plot, labels; plot_set_size=true, zoom=0, temp=temp)
    p1 = Plots.contourf(model.model, model.fitresult, X_plot, labels; plot_set_loss=true, zoom=0, temp=temp)
    p2 = Plots.contourf(model.model, model.fitresult, X_plot, labels; plot_classification_loss=true, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2, 2))
    display(Plots.plot(p0, p1, p2, size=(1400, 320), plot_title=mod_name, layout=(1, 3)))
end
```

## Benchmark

```{julia}
λ₁ = 0.25
λ₂ = 0.75
λ₃ = 0.75
Λ = [λ₁, λ₂, λ₃]
opt = Flux.Optimise.Descent(0.01)
use_class_loss = false

# Benchmark generators:
generator_dict = Dict(
    "Wachter" => WachterGenerator(λ=λ₁, opt=opt),
    "REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
    "Schut" => GreedyGenerator(),
    "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss),
    "ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss),
    "ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss),
)
```

### POC

```{julia}
Random.seed!(2023)

M = model_dict["JEM"]
X = X isa Matrix ? X : Float32.(permutedims(matrix(X)))
factual_label = levels(labels)[1]
x_factual = reshape(X[:, rand(findall(predict_label(M, counterfactual_data) .== factual_label))], input_dim, 1)
target = levels(labels)[2]
factual = predict_label(M, counterfactual_data, x_factual)[1]

ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )
    plt = Plots.plot(
        ce, title=name, alpha=0.2,
        cbar=false,
        # axis=nothing,
    )
    if contains(name, "ECCCo")
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1, :], _X[2, :],
            color=:purple, shape=:star5,
            ms=10, label="x̂|$target", alpha=0.5
        )
    end
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(650, 500))
display(plt)
savefig(plt, joinpath(output_images_path, "linearly_separable_poc.png"))
```

### Complete Benchmark

```{julia}
# Measures:
measures = [
    CounterfactualExplanations.distance,
    ECCCo.distance_from_energy,
    ECCCo.distance_from_targets,
    CounterfactualExplanations.Evaluation.validity,
    CounterfactualExplanations.Evaluation.redundancy,
    ECCCo.set_size_penalty
]

bmks = []
for target in sort(unique(labels))
    for factual in sort(unique(labels))
        if factual == target
            continue
        end
        bmk = benchmark(
            counterfactual_data;
            models=model_dict,
            generators=generator_dict,
            measure=measures,
            suppress_training=true,
            dataname="Linearly Separable",
            n_individuals=25,
            target=target,
            factual=factual,
            initialization=:identity,
            converge_when=:generator_conditions,
        )
        push!(bmks, bmk)
    end
end
bmk = reduce(vcat, bmks)

CSV.write(joinpath(output_path, "linearly_separable_benchmark.csv"), bmk())
```

```{julia}
# Relabel the raw measure names for plotting:
df = @chain bmk() begin
    @mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
    @mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable))
    @mutate(variable = ifelse.(variable .== "distance", "Cost", variable))
    @mutate(variable = ifelse.(variable .== "redundancy", "Redundancy", variable))
    @mutate(variable = ifelse.(variable .== "validity", "Validity", variable))
end
plt = AlgebraOfGraphics.data(df) *
    visual(BoxPlot) *
    mapping(:generator, :value, row=:variable, col=:model, color=:generator)
plt = draw(
    plt,
    axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120),
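    # Unlink the y-axes across facet rows, since the measures live on different scales: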
    facet=(; linkyaxes=:none)
)
display(plt)
save(joinpath(output_images_path, "linearly_separable_benchmark.png"), plt, px_per_unit=5)
```
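Besides the box plots, it can help to tabulate the benchmark results. Below is a minimal sketch of such a summary, assuming (as the plotting cell above already does) that `bmk()` returns a long-format `DataFrame` with `:model`, `:generator`, `:variable` and `:value` columns; the names `df_summary`, `:mean_value` and `:std_value` are illustrative, not part of the pipeline above.

```{julia}
# Hypothetical post-hoc summary (not part of the benchmark pipeline above):
# average each evaluation measure per model/generator pair.
using DataFrames, Statistics

df_summary = combine(
    groupby(bmk(), [:model, :generator, :variable]),
    :value => mean => :mean_value,
    :value => std => :std_value,
)

# Pivot to a wide layout with one column per measure:
unstack(df_summary, [:model, :generator], :variable, :mean_value)
```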