```{julia}
include("notebooks/setup.jl")
eval(setup_notebooks)
```

```{julia}
# Counterfactual data:
counterfactual_data = load_linearly_separable(1000)
X = counterfactual_data.X
y = counterfactual_data.y
labels = counterfactual_data.output_encoder.labels
input_dim, nobs = size(X)
batch_size = Int(round(nobs / 10))
epochs = 100
```

```{julia}
# Plot the data:
Plots.plot()
display(Plots.scatter!(counterfactual_data))
```

```{julia}
# Joint energy-based classifier with a conditional sampler:
𝒟x = Normal()
𝒟y = Categorical(ones(2) ./ 2)
sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=size(X)[1:end-1], batch_size=50)
n_hidden = 16
clf = JointEnergyClassifier(
    sampler;
    builder=MLJFlux.MLP(
        hidden=(n_hidden, n_hidden, n_hidden),
        σ=Flux.swish
    ),
    batch_size=batch_size,
    finaliser=x -> x,
    loss=Flux.Losses.logitcrossentropy,
    jem_training_params=(
        α=[1.0, 1.0, 1e-1],
        verbosity=10,
    ),
    epochs=epochs,
    sampling_steps=30,
)
```

```{julia}
# Conformalize the classifier (simple inductive conformal prediction):
method = :simple_inductive
cov = 0.95
conf_model = conformal_model(clf; method=method, coverage=cov)
mach = machine(conf_model, table(permutedims(X)), labels)
fit!(mach)
Serialization.serialize(joinpath(output_path, "poc_model.jls"), mach)
```

```{julia}
#| echo: false

# Draw conditional samples from the trained joint energy model and plot them
# against the observed data for each class:
niter = 1000
jem = mach.model.model.jem
batch_size = mach.model.model.batch_size
X = Float32.(X)
if typeof(jem.sampler) <: ConditionalSampler
    plts = []
    for target in 1:2
        X̂ = generate_conditional_samples(jem, batch_size, target; niter=niter)
        ex = extrema(hcat(X, X̂), dims=2)
        xlims = ex[1]
        ylims = ex[2]
        x1 = range(1.0f0 .* xlims..., length=100)
        x2 = range(1.0f0 .* ylims..., length=100)
        plt = Plots.contour(
            x1, x2, (x, y) -> softmax(jem([x, y]))[target],
            fill=true, alpha=0.5, title="Target: $target", cbar=true,
            xlims=xlims, ylims=ylims,
        )
        Plots.scatter!(
            X[1, :], X[2, :],
            color=Int.(labels.refs), group=Int.(labels.refs), alpha=0.5
        )
        Plots.scatter!(
            X̂[1, :], X̂[2, :],
            color=repeat([target], size(X̂, 2)),
            group=repeat([target], size(X̂, 2)),
            shape=:star5, ms=10
        )
        push!(plts, plt)
    end
    plt = Plots.plot(plts..., layout=(1, 2), size=(2 * 500, 400))
    display(plt)
end
```

```{julia}
Random.seed!(1234)

# Penalty strengths (distance, set size, energy) and decision threshold:
λ₁ = 0.1
λ₂ = 0.5
λ₃ = 0.5
Λ = [λ₁, λ₂, λ₃]
γ = 0.9     # assumed decision threshold (value not specified in the original)

M = ECCCo.ConformalModel(mach.model, mach.fitresult)

# Pick a random factual from the second class and target the first class:
factual_label = levels(labels)[2]
x_factual = reshape(
    X[:, rand(findall(predict_label(M, counterfactual_data) .== factual_label))],
    input_dim, 1
)
target = levels(labels)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]

generator_dict = OrderedDict(
    "Wachter (γ=$γ)" => WachterGenerator(λ=λ₁),
    "Schut" => GreedyGenerator(λ=λ₁),
    "REVISE" => REVISEGenerator(λ=λ₁),
    "ECCCo" => ECCCoGenerator(λ=Λ),
)

ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
        decision_threshold=γ,
    )
    plt = Plots.plot(ce, title=name, alpha=0.2, cbar=false, axis=nothing, length_out=10)
    if name == "ECCCo"
        # Overlay the conditional samples used by the energy penalty:
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1, :], _X[2, :],
            color=:purple, shape=:star5, ms=10,
            label="x̂|$target", alpha=0.1
        )
    end
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(500, 520))
display(plt)
savefig(plt, joinpath(output_images_path, "poc.png"))
```
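The next chunk decomposes the ECCCo generator's objective into its individual components for the final counterfactual. Schematically (a sketch based on the penalty ordering used below, with $s^\prime$ denoting the counterfactual state, $x$ the factual and $t$ the target label), the generator minimizes

$$
\text{yloss}\big(M(s^\prime), t\big) + \lambda_1 \, \text{dist}(s^\prime, x) + \lambda_2 \, \Omega(s^\prime) + \lambda_3 \, \text{dist}_{\mathcal{E}}(s^\prime),
$$

where $\Omega$ is the set-size penalty and $\text{dist}_{\mathcal{E}}$ the distance-from-energy penalty, corresponding to `generator.penalty[1]`, `[2]` and `[3]` with weights $\Lambda = [\lambda_1, \lambda_2, \lambda_3]$ as set above.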
```{julia}
# Inspect the components of the ECCCo objective for the final counterfactual:
ce = ces["ECCCo"]

using CounterfactualExplanations.Generators: ℓ

# loss
function f1(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    return -ℓ(_ce.generator, _ce)
end

# cost (distance penalty)
function f2(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    λ = _ce.generator.λ[1]
    _loss = _ce.generator.penalty[1]
    return -λ * _loss(_ce)
end

# set size penalty
function f3(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    λ = _ce.generator.λ[2]
    _loss = _ce.generator.penalty[2]
    return -λ * _loss(_ce)
end

# distance from energy penalty
function f4(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    λ = _ce.generator.λ[3]
    _loss = _ce.generator.penalty[3]
    return -λ * _loss(_ce)
end

# Helper function:
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))

# Evaluate the gradient of the cost component on a coarse grid and plot it as a
# vector field over the data:
xlims, ylims = extrema(X, dims=2)
xrange = range(xlims..., length=10)
yrange = range(ylims..., length=10)
x1, x2 = meshgrid(xrange, yrange)
inputs = zip(x1, x2)
u = []
v = []
scale = 0.1
for (x, y) in inputs
    push!(u, scale * gradient(f2, [x, y][:, :])[1][1])
    push!(v, scale * gradient(f2, [x, y][:, :])[1][2])
end
Plots.plot(xlims=xlims, ylims=ylims)
Plots.scatter!(counterfactual_data)
Plots.quiver!(x1, x2, quiver=(u, v))
```
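A natural extension is to visualize the pull of the full objective rather than just the cost component. The sketch below reuses the partial objectives `f1`–`f4` and the grid from the previous chunk; it assumes the same `gradient` call used above also applies to their sum. Since each component is defined as a negated term, the arrows point in the descent direction of the combined objective.

```{julia}
# Gradient field of the combined objective (sum of the negated components above):
f_total(x) = f1(x) + f2(x) + f3(x) + f4(x)

u_tot = []
v_tot = []
for (x, y) in zip(x1, x2)
    g = gradient(f_total, [x, y][:, :])[1]
    push!(u_tot, scale * g[1])
    push!(v_tot, scale * g[2])
end
Plots.plot(xlims=xlims, ylims=ylims)
Plots.scatter!(counterfactual_data)
Plots.quiver!(x1, x2, quiver=(u_tot, v_tot))
```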