diff --git a/Manifest.toml b/Manifest.toml index b0a164f37f78170a3bac2baa6095d0cff2278896..d2e997cf1ad3de8f5cfa790e14dd6ede8e6a289b 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "c05645661b945bea85f24ae44163718eb7f430c7" +project_hash = "d4865488d9d741cc2d833e66a476d21fceebe412" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] @@ -275,7 +275,7 @@ uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3" version = "0.3.2" [[deps.ConformalPrediction]] -deps = ["CategoricalArrays", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "StatsBase"] +deps = ["CategoricalArrays", "ChainRules", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "StatsBase"] path = "../ConformalPrediction.jl" uuid = "98bfc277-1877-43dc-819b-a3e38c30242f" version = "0.1.6" @@ -767,6 +767,12 @@ git-tree-sha1 = "84b10656a41ef564c39d2d477d7236966d2b5683" uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" version = "1.12.0" +[[deps.JointEnergyModels]] +deps = ["Distributions", "Flux", "StatsBase"] +path = "../JointEnergyModels.jl" +uuid = "48c56d24-211d-4463-bbc0-7a701b291131" +version = "0.1.0" + [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "6f2675ef130a300a112286de91973805fcc5ffbc" diff --git a/Project.toml b/Project.toml index 62158dabd1a8697e13ebf6213ac16b140d13e3dc..a6f587dfbc4eadb897077e1c0bc78777a9af114e 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,9 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd index 1dfdff9849fe6132a4c23876bee4bd30000d5c52..f592ac3d69214798adddeeb02c169f186ee2e69d 100644 --- a/notebooks/conformal.qmd +++ b/notebooks/conformal.qmd @@ -120,21 +120,20 @@ plot(p1, p2, size=(800,320)) opt = Descent(0.01) ordered_names = [ "Generic (γ=0.5)", - "Generic (γ=0.9)", "Conformal (λ₂=1)", "Conformal (λ₂=10)" ] loss_fun = Objectives.logitbinarycrossentropy +generator = GenericGenerator(opt = opt) # Generators: generators = Dict( - ordered_names[1] => GenericGenerator(opt = opt, decision_threshold=0.5), - ordered_names[2] => GenericGenerator(opt = opt, decision_threshold=0.9), - ordered_names[3] => CCE.ConformalGenerator(loss=loss_fun, opt=opt, λ=[0.1,1]), - ordered_names[4] => CCE.ConformalGenerator(loss=loss_fun, opt=opt, λ=[0.1,10]), + ordered_names[1] => generator, + ordered_names[2] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance_l2 + 1.0set_size_penalty), + ordered_names[3] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance_l2 + 10.0set_size_penalty), ) -counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity) for (name, gen) in generators]) +counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity, converge_when=:generator_conditions, gradient_tol=1e-3) for (name, gen) in generators]) # Plots: plts = [] diff --git a/notebooks/fidelity.qmd b/notebooks/fidelity.qmd new file mode 100644 index 0000000000000000000000000000000000000000..7a57cc0b6562cacf0e1189508a262a67db50cbc7 --- /dev/null +++ b/notebooks/fidelity.qmd @@ -0,0 +1,44 @@ +```{julia} +using CCE +using ConformalPrediction +using CounterfactualExplanations +using CounterfactualExplanations.Data +using CounterfactualExplanations.Objectives +using Distributions +using Flux +using JointEnergyModels +using LinearAlgebra +using MLJBase +using MLJFlux +using Plots +``` + +# Fidelity Measures + +```{julia} +# Setup +counterfactual_data = load_multi_class() +M = fit_model(counterfactual_data, :MLP) +target = 4 +factual = 1 +chosen = rand(findall(predict_label(M, counterfactual_data) .== factual)) +x = select_factual(counterfactual_data, chosen) +``` + + +```{julia} +niter = 10 +nsamples = 100 +plts = [] +for target in ce.data.y_levels + # Search: + generator = GenericGenerator() + ce = generate_counterfactual(x, target, counterfactual_data, M, generator) + sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100) + Xgen = rand(sampler, nsamples) + plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false) + scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target") + push!(plts, plt) +end +plot(plts..., layout=(1,length(ce.data.y_levels)), size=(length(ce.data.y_levels)*300,300)) +``` \ No newline at end of file diff --git a/src/CCE.jl b/src/CCE.jl index e0fe38fa371838c1d18b50807ab0dd3004894e45..c8729d9aa4746a0374b5bc86e2a57cd087d9bc9f 100644 --- a/src/CCE.jl +++ b/src/CCE.jl @@ -1,8 +1,13 @@ module CCE +using CounterfactualExplanations import MLJModelInterface as MMI include("model.jl") -include("ConformalGenerator.jl") +include("penalties.jl") +include("losses.jl") +include("generator.jl") +include("sampling.jl") +# include("ConformalGenerator.jl") end diff --git a/src/generator.jl b/src/generator.jl new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/losses.jl b/src/losses.jl new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/penalties.jl b/src/penalties.jl new file mode 100644 index 0000000000000000000000000000000000000000..611fb3e799ef3d27f654c5009ce2df5a021701bd --- /dev/null +++ b/src/penalties.jl @@ -0,0 +1,43 @@ +using Statistics: mean + +""" + set_size_penalty(counterfactual_explanation::AbstractCounterfactualExplanation) + +Penalty for smooth conformal set size. +""" +function set_size_penalty( + counterfactual_explanation::AbstractCounterfactualExplanation; + κ::Real=1.0, temp::Real=0.5, agg=mean +) + + conf_model = counterfactual_explanation.M.model + fitresult = counterfactual_explanation.M.fitresult + X = CounterfactualExplanations.decode_state(counterfactual_explanation) + loss = map(eachslice(X, dims=3)) do x + x = Matrix(x) + if target_probs(counterfactual_explanation, x)[1] >= 0.5 + l = ConformalPrediction.smooth_size_loss( + conf_model, fitresult, x; + κ=κ, + temp=temp + )[1] + else + l = 0.0 + end + return l + end + loss = agg(loss) + + return loss + +end + +function distance_from_energy( + counterfactual_explanation::AbstractCounterfactualExplanation; + n::Int=100, retrain=false, kwargs... +) + sampler = get!(counterfactual_explanation.params, :energy_sampler) do + CCE.EnergySampler(counterfactual_explanation; kwargs...) + end + conditional_samples = +end \ No newline at end of file diff --git a/src/sampling.jl b/src/sampling.jl new file mode 100644 index 0000000000000000000000000000000000000000..2cf346ca70ac22e307080bf1f06c0e44738e8396 --- /dev/null +++ b/src/sampling.jl @@ -0,0 +1,43 @@ +using CounterfactualExplanations +using Distributions +using JointEnergyModels + +(model::AbstractFittedModel)(x) = logits(model, x) + +mutable struct EnergySampler + ce::CounterfactualExplanation + sampler::JointEnergyModels.ConditionalSampler + opt::JointEnergyModels.AbstractSamplingRule + buffer::AbstractArray +end + +function EnergySampler( + ce::CounterfactualExplanation; + opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(), + niter::Int=100, + nsamples::Int=1000 +) + + # Setup: + model = ce.M + data = ce.data + K = length(data.y_levels) + ð’Ÿx = Normal() + ð’Ÿy = Categorical(ones(K) ./ K) + sampler = ConditionalSampler(ð’Ÿx, ð’Ÿy) + + # Fit: + i = get_target_index(data.y_levels, ce.target) + buffer = sampler(model.model, opt, (size(data.X, 1), nsamples); niter=niter, y=i) + + return EnergySampler(ce, sampler, opt, buffer) +end + +function Base.rand(sampler::EnergySampler, n::Int=100; retrain=false) + ntotal = size(sampler.buffer,2) + idx = rand(1:ntotal, n) + if !retrain + X = sampler.buffer[:,idx] + end + return X +end