```{julia}
# Load the shared notebook setup and evaluate the expression it provides.
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
```

```{julia}
# Counterfactual data:
n_obs = 1000
counterfactual_data = load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0))
X = counterfactual_data.X
y = counterfactual_data.y
labels = counterfactual_data.output_encoder.labels
input_dim, nobs = size(X)
# Training batch size: a tenth of the observations.
batch_size = round(Int, nobs / 10)
epochs = 100
```

```{julia}
# Scatter plot of the data.
Plots.plot()
display(Plots.scatter!(counterfactual_data))
```

```{julia}
# Joint energy classifier built around a conditional sampler.
𝒟x = Normal()
𝒟y = Categorical(ones(2) ./ 2)
sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=size(X)[1:end-1], batch_size=50)
n_hidden = 16
clf = JointEnergyClassifier(
    sampler;
    builder=MLJFlux.MLP(
        hidden=(n_hidden, n_hidden, n_hidden),
        σ=Flux.swish,
    ),
    batch_size=batch_size,
    finaliser=Flux.softmax,
    loss=Flux.Losses.crossentropy,
    jem_training_params=(
        α=[1.0, 1.0, 1e-1],
        verbosity=10,
    ),
    epochs=epochs,
    sampling_steps=30,
)
```

```{julia}
# Wrap the classifier in a conformal model, fit it and serialize the machine.
method = :simple_inductive
cov = 0.95
conf_model = conformal_model(clf; method=method, coverage=cov)
mach = machine(conf_model, table(permutedims(X)), labels)
fit!(mach)
Serialization.serialize(joinpath(output_path, "poc_model.jls"), mach)
```

```{julia}
#| echo: false
niter = 1000
jem = mach.model.model.jem
batch_size = mach.model.model.batch_size
X = Float32.(matrix(X))

# One panel per target: contour of the softmax output for that target, the
# data, and the conditionally generated samples (stars).
if jem.sampler isa ConditionalSampler
    plts = []
    for target in 1:2
        X̂ = generate_conditional_samples(jem, batch_size, target; niter=niter)
        # Axis limits cover both the data and the generated samples.
        ex = extrema(hcat(X, X̂), dims=2)
        xlims = ex[1]
        ylims = ex[2]
        x1 = range((1.0f0 .* xlims)..., length=100)
        x2 = range((1.0f0 .* ylims)..., length=100)
        plt = Plots.contour(
            x1,
            x2,
            (x, y) -> softmax(jem([x, y]))[target],
            fill=true,
            alpha=0.5,
            title="Target: $target",
            cbar=true,
            xlims=xlims,
            ylims=ylims,
        )
        label_refs = Int.(labels.refs)
        Plots.scatter!(X[1, :], X[2, :], color=label_refs, group=label_refs, alpha=0.5)
        sample_labels = fill(target, size(X̂, 2))
        Plots.scatter!(
            X̂[1, :],
            X̂[2, :],
            color=sample_labels,
            group=sample_labels,
            shape=:star5,
            ms=10,
        )
        push!(plts, plt)
    end
    plt = Plots.plot(plts..., layout=(1, 2), size=(2 * 500, 400))
    display(plt)
end
```

```{julia}
#| output: true
#| echo: false
#| label: fig-losses
#| fig-cap: "Illustration of the smooth size loss and the configurable classification loss."
temp = 0.25
p1 = Plots.contourf(
    mach.model, mach.fitresult, permutedims(X), labels;
    plot_set_loss=true, zoom=0, temp=temp,
)
p2 = Plots.contourf(
    mach.model, mach.fitresult, permutedims(X), labels;
    plot_classification_loss=true, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2, 2),
)
display(Plots.plot(p1, p2, size=(800, 320)))
```

```{julia}
# Compare counterfactual generators on a single factual.
Random.seed!(1234)
λ₁ = 0.1
λ₂ = 0.4
λ₃ = 0.4
Λ = [λ₁, λ₂, λ₃]

M = ECCCo.ConformalModel(mach.model, mach.fitresult)
factual_label = levels(labels)[2]
# Draw the factual at random among the samples predicted as `factual_label`.
factual_candidates = findall(predict_label(M, counterfactual_data) .== factual_label)
isempty(factual_candidates) &&
    error("No sample is predicted as $(factual_label); cannot choose a factual.")
x_factual = reshape(X[:, rand(factual_candidates)], input_dim, 1)
target = levels(labels)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]

opt = Flux.Optimise.Descent(0.01)
generator_dict = OrderedDict(
    "Wachter" => WachterGenerator(λ=λ₁, opt=opt),
    "Schut" => GreedyGenerator(λ=λ₁),
    "REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
    "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt),
)

ces = Dict{String,Any}()
plts = []
for (name, generator) in generator_dict
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )
    plt = Plots.plot(
        ce,
        title=name,
        alpha=0.2,
        cbar=false,
        axis=nothing,
        length_out=10,
        contour_alpha=1.0,
    )
    # For ECCCo, overlay the samples returned by `distance_from_energy`.
    if name == "ECCCo"
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1, :],
            _X[2, :],
            color=:purple,
            shape=:star5,
            ms=10,
            label="x̂|$target",
            alpha=0.1,
        )
    end
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(500, 520))
display(plt)
savefig(plt, joinpath(output_images_path, "poc.png"))
```

```{julia}
using Colors
col_pal = palette(:seaborn_colorblind)
Random.seed!(1234)
using CounterfactualExplanations.Generators: ∇

λ₁ = 0.1
λ₂ = 1.0
λ₃ = 0.5
Λ = [λ₁, λ₂, λ₃]
η = 0.01  # step size of the optimiser; also used to scale the quiver arrows

M = ECCCo.ConformalModel(mach.model, mach.fitresult)
factual_label = levels(labels)[2]
# Draw the factual at random among the samples predicted as `factual_label`.
factual_candidates = findall(predict_label(M, counterfactual_data) .== factual_label)
isempty(factual_candidates) &&
    error("No sample is predicted as $(factual_label); cannot choose a factual.")
x_factual = reshape(X[:, rand(factual_candidates)], input_dim, 1)
target = levels(labels)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]

opt = Flux.Optimise.Descent(η)
generator_dict = OrderedDict(
    "Wachter" => WachterGenerator(λ=0.3, opt=opt),
    "ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0], opt=opt),
    "ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃], opt=opt),
    "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt),
)

# Gradient field:
# Evaluate `∇(generator, model, ·)` on a copy of `ce` whose state `s′` has been
# replaced by `x` (converted to Float32). The copy keeps `ce` itself untouched.
function loss_grads(generator, model, ce, x)
    _ce = deepcopy(ce)
    _ce.s′ = Float32.(x)
    return ∇(generator, model, _ce)
end

# Flattened coordinates of the Cartesian grid spanned by `x` and `y`.
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))

xlims, ylims = extrema(X, dims=2)
xrange = range(xlims..., length=10)
yrange = range(ylims..., length=10)
x1, x2 = meshgrid(xrange, yrange)
inputs = zip(x1, x2)

# Plot:
ces = Dict{String,Any}()
plts = []
for (name, generator) in generator_dict
    # CE:
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )

    # Main plot (path):
    plt = Plots.plot(
        ce,
        title=name,
        alpha=0.1,
        cbar=false,
        axis=nothing,
        length_out=10,
        contour_alpha=0.0,
        legend=false,
        palette=col_pal,
    )

    # Generated samples:
    if name ∈ ["ECCCo", "ECCCo (no CP)"]
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1, :],
            _X[2, :],
            color=col_pal[end-1],
            shape=:star5,
            ms=10,
            label="x̂|$target",
            alpha=0.1,
        )
    end

    # Gradient field: negative gradient at every grid point, scaled by η.
    # `[x, y][:, :]` turns the 2-vector into a 2×1 matrix.
    grads = [-loss_grads(generator, M, ce, [x, y][:, :]) for (x, y) in inputs]
    u = [η * g[1] for g in grads]
    v = [η * g[2] for g in grads]
    Plots.quiver!(x1, x2, quiver=(u, v), color=col_pal[5])

    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(500, 520))
display(plt)
savefig(plt, joinpath(output_images_path, "poc_gradient_fields.png"))
```

```{julia}
ce = ces["ECCCo"]
using CounterfactualExplanations.Generators: ℓ

# Negative weighted penalty `i` of the ECCCo generator, evaluated on a copy of
# the global `ce` whose state `s′` has been replaced by `x`.
function weighted_penalty(x, i)
    _ce = deepcopy(ce)
    _ce.s′ = x
    λ = _ce.generator.λ[i]
    _loss = _ce.generator.penalty[i]
    return -λ * _loss(_ce)
end

# loss
function f1(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    return -ℓ(_ce.generator, _ce)
end

# cost
f2(x) = weighted_penalty(x, 1)

# set size
f3(x) = weighted_penalty(x, 2)

# distance from energy
f4(x) = weighted_penalty(x, 3)

# Helper function:
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))

xlims, ylims = extrema(X, dims=2)
xrange = range(xlims..., length=10)
yrange = range(ylims..., length=10)
x1, x2 = meshgrid(xrange, yrange)
inputs = zip(x1, x2)

# Arrow scaling. The original referred to an undefined `scale`; the name
# `arrow_scale` is used so it cannot clash with `Distributions.scale`.
# NOTE(review): η (the step size from the previous chunk) is assumed to be the
# intended factor — confirm.
arrow_scale = η

# Evaluate the gradient of `f1` once per grid point and reuse it for both
# arrow components.
grads = [gradient(f1, [x, y][:, :])[1] for (x, y) in inputs]
u = [arrow_scale * g[1] for g in grads]
v = [arrow_scale * g[2] for g in grads]

Plots.plot(xlims=xlims, ylims=ylims)
Plots.scatter!(counterfactual_data)
Plots.quiver!(x1, x2, quiver=(u, v))
```