diff --git a/notebooks/fidelity.qmd b/notebooks/fidelity.qmd index 2b7e384ace148025834ba79617de24535c9959ec..def8c2757a419438923f3367cd976c2ffaee6ca4 100644 --- a/notebooks/fidelity.qmd +++ b/notebooks/fidelity.qmd @@ -20,25 +20,32 @@ using Plots counterfactual_data = load_multi_class() M = fit_model(counterfactual_data, :MLP) target = 4 -factual = 1 +factual = 2 chosen = rand(findall(predict_label(M, counterfactual_data) .== factual)) x = select_factual(counterfactual_data, chosen) -``` +# Search: +generator = GenericGenerator(opt=Descent(0.01)) +ce = generate_counterfactual(x, target, counterfactual_data, M, generator) +``` ```{julia} -niter = 10000 +niter = 100 nsamples = 100 -plts = [] -for target in ce.data.y_levels - # Search: - generator = GenericGenerator() - ce = generate_counterfactual(x, target, counterfactual_data, M, generator) - sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100) - Xgen = rand(sampler, nsamples) - plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false) - scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target") - push!(plts, plt) -end -plot(plts..., layout=(1,length(ce.data.y_levels)), size=(length(ce.data.y_levels)*300,300)) + +sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100) +Xgen = rand(sampler, nsamples) +plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false) +scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target") +``` + +```{julia} +p1 = plot(ce) +``` + + +```{julia} +@objective(generator, _ + 0.1distance_l2 + 100.0distance_from_energy) +ce = generate_counterfactual(x, target, counterfactual_data, M, generator) +p2 = plot(ce) ``` \ No newline at end of file diff --git a/src/penalties.jl b/src/penalties.jl index 611fb3e799ef3d27f654c5009ce2df5a021701bd..a3f35a9b3291f0919584e32ff01805ef22ca009d 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -1,3 +1,5 @@ +using ChainRules: ignore_derivatives +using LinearAlgebra: norm using Statistics: mean """ @@ -33,11 +35,28 @@ function set_size_penalty( end function distance_from_energy( - counterfactual_explanation::AbstractCounterfactualExplanation; - n::Int=100, retrain=false, kwargs... + counterfactual_explanation::AbstractCounterfactualExplanation; + n::Int=100, retrain=false, agg=mean, kwargs... ) - sampler = get!(counterfactual_explanation.params, :energy_sampler) do - CCE.EnergySampler(counterfactual_explanation; kwargs...) + conditional_samples = [] + ignore_derivatives() do + _dict = counterfactual_explanation.params + if !(:energy_sampler ∈ collect(keys(_dict))) + _dict[:energy_sampler] = CCE.EnergySampler(counterfactual_explanation; kwargs...) + end + sampler = _dict[:energy_sampler] + push!(conditional_samples, rand(sampler, n; retrain=retrain)) + end + x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation) + loss = map(eachslice(x′, dims=3)) do x + x = Matrix(x) + Δ = map(eachcol(conditional_samples[1])) do xsample + norm(x - xsample) + end + return mean(Δ) end - conditional_samples = + loss = agg(loss) + + return loss + end \ No newline at end of file diff --git a/src/sampling.jl b/src/sampling.jl index 8a6625ee9d1b02ba37b8263bf3772d2ee9c6e0ce..2cf346ca70ac22e307080bf1f06c0e44738e8396 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -13,7 +13,7 @@ end function EnergySampler( ce::CounterfactualExplanation; - opt::JointEnergyModels.AbstractSamplingRule=SGLD(), + opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(), niter::Int=100, nsamples::Int=1000 )