diff --git a/notebooks/fidelity.qmd b/notebooks/fidelity.qmd index def8c2757a419438923f3367cd976c2ffaee6ca4..155bc957de66635ff7bbf61274e09d9eab79490b 100644 --- a/notebooks/fidelity.qmd +++ b/notebooks/fidelity.qmd @@ -15,10 +15,38 @@ using Plots # Fidelity Measures +## Binary + +```{julia} +# Setup +counterfactual_data = load_linearly_separable() +M = fit_model(counterfactual_data, :DeepEnsemble) +target = 2 +factual = 1 +chosen = rand(findall(predict_label(M, counterfactual_data) .== factual)) +x = select_factual(counterfactual_data, chosen) + +# Search: +generator = GenericGenerator(opt=Descent(0.01)) +ce = generate_counterfactual(x, target, counterfactual_data, M, generator) +``` + +```{julia} +niter = 100 +nsamples = 100 + +sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100) +Xgen = rand(sampler, nsamples) +plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false) +scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target") +``` + +## Multi-Class + ```{julia} # Setup counterfactual_data = load_multi_class() -M = fit_model(counterfactual_data, :MLP) +M = fit_model(counterfactual_data, :DeepEnsemble) target = 4 factual = 2 chosen = rand(findall(predict_label(M, counterfactual_data) .== factual)) @@ -45,7 +73,19 @@ p1 = plot(ce) ```{julia} -@objective(generator, _ + 0.1distance_l2 + 100.0distance_from_energy) +using CCE: distance_from_energy +@objective(generator, _ + 0.1distance_l2 + 10.0distance_from_energy) ce = generate_counterfactual(x, target, counterfactual_data, M, generator) p2 = plot(ce) +``` + + +```{julia} +using CCE: distance_from_targets +@objective( + generator, + _ + 0.1distance_l2 + 1.0distance_from_energy + 10.0distance_from_targets +) +ce = generate_counterfactual(x, target, counterfactual_data, M, generator) +p3 = plot(ce) ``` \ No newline at end of file diff --git a/src/penalties.jl b/src/penalties.jl index a3f35a9b3291f0919584e32ff01805ef22ca009d..5f85af406725ad258610b65a54ce8ecd371f13c2 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -59,4 +59,25 @@ function distance_from_energy( return loss -end \ No newline at end of file +end + +function distance_from_targets( + counterfactual_explanation::AbstractCounterfactualExplanation; + n::Int=100, agg=mean +) + target_samples = counterfactual_explanation.data.X |> + X -> X[:,rand(1:end,n)] + x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation) + loss = map(eachslice(x′, dims=3)) do x + x = Matrix(x) + Δ = map(eachcol(target_samples)) do xsample + norm(x - xsample) + end + return mean(Δ) + end + loss = agg(loss) + + return loss + +end + diff --git a/src/sampling.jl b/src/sampling.jl index 2cf346ca70ac22e307080bf1f06c0e44738e8396..29ac4c9aedb8bcb409fd7f543cf9ceaf1516c394 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -1,8 +1,9 @@ using CounterfactualExplanations using Distributions +using Flux using JointEnergyModels -(model::AbstractFittedModel)(x) = logits(model, x) +(model::AbstractFittedModel)(x) = log.(CounterfactualExplanations.predict_proba(model, nothing, x)) mutable struct EnergySampler ce::CounterfactualExplanation @@ -28,7 +29,7 @@ function EnergySampler( # Fit: i = get_target_index(data.y_levels, ce.target) - buffer = sampler(model.model, opt, (size(data.X, 1), nsamples); niter=niter, y=i) + buffer = sampler(model, opt, (size(data.X, 1), nsamples); niter=niter, y=i) return EnergySampler(ce, sampler, opt, buffer) end