From 0c057ef953c30c67b0eb7ede7b934a6d1959c1e0 Mon Sep 17 00:00:00 2001 From: pat-alt <altmeyerpat@gmail.com> Date: Fri, 17 Mar 2023 14:55:41 +0100 Subject: [PATCH] more on sampling --- Manifest.toml | 10 ++++++++-- Project.toml | 2 ++ notebooks/conformal.qmd | 11 +++++------ notebooks/fidelity.qmd | 44 +++++++++++++++++++++++++++++++++++++++++ src/CCE.jl | 7 ++++++- src/generator.jl | 0 src/losses.jl | 0 src/penalties.jl | 43 ++++++++++++++++++++++++++++++++++++++++ src/sampling.jl | 43 ++++++++++++++++++++++++++++++++++++++++ 9 files changed, 151 insertions(+), 9 deletions(-) create mode 100644 notebooks/fidelity.qmd create mode 100644 src/generator.jl create mode 100644 src/losses.jl create mode 100644 src/penalties.jl create mode 100644 src/sampling.jl diff --git a/Manifest.toml b/Manifest.toml index b0a164f3..d2e997cf 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "c05645661b945bea85f24ae44163718eb7f430c7" +project_hash = "d4865488d9d741cc2d833e66a476d21fceebe412" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] @@ -275,7 +275,7 @@ uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3" version = "0.3.2" [[deps.ConformalPrediction]] -deps = ["CategoricalArrays", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "StatsBase"] +deps = ["CategoricalArrays", "ChainRules", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "StatsBase"] path = "../ConformalPrediction.jl" uuid = "98bfc277-1877-43dc-819b-a3e38c30242f" version = "0.1.6" @@ -767,6 +767,12 @@ git-tree-sha1 = "84b10656a41ef564c39d2d477d7236966d2b5683" uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" version = "1.12.0" +[[deps.JointEnergyModels]] +deps = ["Distributions", "Flux", "StatsBase"] +path = "../JointEnergyModels.jl" +uuid = "48c56d24-211d-4463-bbc0-7a701b291131" +version = "0.1.0" + [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = 
"6f2675ef130a300a112286de91973805fcc5ffbc" diff --git a/Project.toml b/Project.toml index 62158dab..a6f587df 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,9 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd index 1dfdff98..f592ac3d 100644 --- a/notebooks/conformal.qmd +++ b/notebooks/conformal.qmd @@ -120,21 +120,20 @@ plot(p1, p2, size=(800,320)) opt = Descent(0.01) ordered_names = [ "Generic (γ=0.5)", - "Generic (γ=0.9)", "Conformal (λ₂=1)", "Conformal (λ₂=10)" ] loss_fun = Objectives.logitbinarycrossentropy +generator = GenericGenerator(opt = opt) # Generators: generators = Dict( - ordered_names[1] => GenericGenerator(opt = opt, decision_threshold=0.5), - ordered_names[2] => GenericGenerator(opt = opt, decision_threshold=0.9), - ordered_names[3] => CCE.ConformalGenerator(loss=loss_fun, opt=opt, λ=[0.1,1]), - ordered_names[4] => CCE.ConformalGenerator(loss=loss_fun, opt=opt, λ=[0.1,10]), + ordered_names[1] => generator, + ordered_names[2] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance_l2 + 1.0set_size_penalty), + ordered_names[3] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance_l2 + 10.0set_size_penalty), ) -counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity) for (name, gen) in generators]) +counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; 
initialization=:identity, converge_when=:generator_conditions, gradient_tol=1e-3) for (name, gen) in generators])
 
 # Plots:
 plts = []
diff --git a/notebooks/fidelity.qmd b/notebooks/fidelity.qmd
new file mode 100644
index 00000000..7a57cc0b
--- /dev/null
+++ b/notebooks/fidelity.qmd
@@ -0,0 +1,44 @@
+```{julia}
+using CCE
+using ConformalPrediction
+using CounterfactualExplanations
+using CounterfactualExplanations.Data
+using CounterfactualExplanations.Objectives
+using Distributions
+using Flux
+using JointEnergyModels
+using LinearAlgebra
+using MLJBase
+using MLJFlux
+using Plots
+```
+
+# Fidelity Measures
+
+```{julia}
+# Setup
+counterfactual_data = load_multi_class()
+M = fit_model(counterfactual_data, :MLP)
+target = 4
+factual = 1
+chosen = rand(findall(predict_label(M, counterfactual_data) .== factual))
+x = select_factual(counterfactual_data, chosen)
+```
+
+
+```{julia}
+niter = 10
+nsamples = 100
+plts = []
+for target in counterfactual_data.y_levels  # `ce` only exists inside the loop body
+    # Search for a counterfactual for this target, then sample conditional on it:
+    generator = GenericGenerator()
+    ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
+    sampler = CCE.EnergySampler(ce; niter=niter, nsamples=nsamples)
+    Xgen = rand(sampler, nsamples)
+    plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false)
+    scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
+    push!(plts, plt)
+end
+plot(plts..., layout=(1,length(counterfactual_data.y_levels)), size=(length(counterfactual_data.y_levels)*300,300))
+```
\ No newline at end of file
diff --git a/src/CCE.jl b/src/CCE.jl
index e0fe38fa..c8729d9a 100644
--- a/src/CCE.jl
+++ b/src/CCE.jl
@@ -1,8 +1,13 @@
 module CCE
 
+using CounterfactualExplanations
 import MLJModelInterface as MMI
 
 include("model.jl")
-include("ConformalGenerator.jl")
+include("penalties.jl")
+include("losses.jl")
+include("generator.jl")
+include("sampling.jl")
+# include("ConformalGenerator.jl")
 
 end
diff --git a/src/generator.jl b/src/generator.jl
new file mode 100644
index 00000000..e69de29b
diff 
--git a/src/losses.jl b/src/losses.jl
new file mode 100644
index 00000000..e69de29b
diff --git a/src/penalties.jl b/src/penalties.jl
new file mode 100644
index 00000000..611fb3e7
--- /dev/null
+++ b/src/penalties.jl
@@ -0,0 +1,74 @@
+using Statistics: mean
+
+"""
+    set_size_penalty(counterfactual_explanation; κ=1.0, temp=0.5, agg=mean)
+
+Penalty for smooth conformal set size.
+
+The decoded counterfactual state is sliced along its third dimension. A slice adds
+`ConformalPrediction.smooth_size_loss(...)[1]` when its first target probability is
+at least 0.5 and `0.0` otherwise; `agg` then reduces the per-slice values.
+
+# Keywords
+- `κ`, `temp`: forwarded unchanged to `ConformalPrediction.smooth_size_loss`.
+- `agg`: reduction applied to the per-slice penalties (default `mean`).
+"""
+function set_size_penalty(
+    counterfactual_explanation::AbstractCounterfactualExplanation;
+    κ::Real=1.0, temp::Real=0.5, agg=mean
+)
+
+    conf_model = counterfactual_explanation.M.model
+    fitresult = counterfactual_explanation.M.fitresult
+    X = CounterfactualExplanations.decode_state(counterfactual_explanation)
+    loss = map(eachslice(X, dims=3)) do x
+        x = Matrix(x)
+        if target_probs(counterfactual_explanation, x)[1] >= 0.5
+            l = ConformalPrediction.smooth_size_loss(
+                conf_model, fitresult, x;
+                κ=κ,
+                temp=temp
+            )[1]
+        else
+            l = 0.0
+        end
+        return l
+    end
+    loss = agg(loss)
+
+    return loss
+
+end
+
+"""
+    distance_from_energy(
+        counterfactual_explanation; n=100, retrain=false, agg=mean, kwargs...
+    )
+
+Penalty for the distance between the counterfactual and samples drawn from the
+`EnergySampler` cached under `:energy_sampler` in `counterfactual_explanation.params`.
+
+The sampler is built on first use with `kwargs...` and reused afterwards. For every
+slice of the decoded state (third dimension) the mean Euclidean distance to the `n`
+drawn samples is computed, and `agg` reduces those per-slice values.
+
+NOTE(review): the original body stopped at a dangling `conditional_samples =`, which
+does not parse. The distance used here is the reviewer's completion; confirm it is the
+intended measure. It assumes each slice holds exactly one point.
+"""
+function distance_from_energy(
+    counterfactual_explanation::AbstractCounterfactualExplanation;
+    n::Int=100, retrain=false, agg=mean, kwargs...
+)
+    sampler = get!(counterfactual_explanation.params, :energy_sampler) do
+        CCE.EnergySampler(counterfactual_explanation; kwargs...)
+    end
+    conditional_samples = rand(sampler, n; retrain=retrain)
+    X = CounterfactualExplanations.decode_state(counterfactual_explanation)
+    loss = map(eachslice(X, dims=3)) do x
+        point = vec(x)
+        return mean(sqrt(sum(abs2, point .- s)) for s in eachcol(conditional_samples))
+    end
+
+    return agg(loss)
+end
\ No newline at end of file
diff --git a/src/sampling.jl b/src/sampling.jl
new file mode 100644
index 00000000..2cf346ca
--- /dev/null
+++ b/src/sampling.jl
@@ -0,0 +1,91 @@
+using CounterfactualExplanations
+using Distributions
+using JointEnergyModels
+
+# NOTE(review): `AbstractFittedModel` is not defined in this file; if it is owned by
+# CounterfactualExplanations, this call overload is type piracy. Kept as-is so existing
+# callers keep working — consider wrapping the model in a local type instead.
+(model::AbstractFittedModel)(x) = logits(model, x)
+
+"""
+    EnergySampler
+
+Bundle a counterfactual explanation with the conditional sampler, the sampling rule and
+a buffer of samples generated for the explanation's target.
+
+# Fields
+- `ce`: counterfactual explanation the sampler was built for.
+- `sampler`: `JointEnergyModels.ConditionalSampler` that generates the samples.
+- `opt`: sampling rule handed to `sampler`.
+- `buffer`: generated samples; `Base.rand` draws its columns.
+"""
+mutable struct EnergySampler
+    ce::CounterfactualExplanation
+    sampler::JointEnergyModels.ConditionalSampler
+    opt::JointEnergyModels.AbstractSamplingRule
+    buffer::AbstractArray
+end
+
+# Run `sampler` on the model wrapped by `ce`, conditional on the index of `ce.target`,
+# and return the generated buffer. Shared by the constructor and by `rand` when
+# `retrain=true`.
+function _generate_buffer(ce, sampler, opt, nsamples::Int; niter::Int=100)
+    data = ce.data
+    i = get_target_index(data.y_levels, ce.target)
+    return sampler(ce.M.model, opt, (size(data.X, 1), nsamples); niter=niter, y=i)
+end
+
+"""
+    EnergySampler(ce; opt=ImproperSGLD(), niter=100, nsamples=1000)
+
+Build an `EnergySampler` for `ce` and fill its buffer.
+
+The conditional sampler pairs a `Normal()` input distribution with a uniform
+`Categorical` over the `length(ce.data.y_levels)` classes.
+
+# Keywords
+- `opt`: sampling rule (default `ImproperSGLD()`).
+- `niter`: forwarded to the sampler as `niter`.
+- `nsamples`: second dimension of the requested buffer.
+"""
+function EnergySampler(
+    ce::CounterfactualExplanation;
+    opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(),
+    niter::Int=100,
+    nsamples::Int=1000
+)
+
+    # Setup:
+    K = length(ce.data.y_levels)
+    𝒟x = Normal()
+    𝒟y = Categorical(ones(K) ./ K)
+    sampler = ConditionalSampler(𝒟x, 𝒟y)
+
+    # Fit:
+    buffer = _generate_buffer(ce, sampler, opt, nsamples; niter=niter)
+
+    return EnergySampler(ce, sampler, opt, buffer)
+end
+
+"""
+    rand(sampler::EnergySampler, n::Int=100; retrain=false, niter::Int=100)
+
+Return `n` columns of the sample buffer, chosen uniformly with replacement.
+
+With `retrain=true` the buffer is regenerated first (same number of columns, `niter`
+forwarded to the sampler) and stored back on `sampler`.
+
+NOTE(review): the original only assigned the result when `retrain == false`, so
+`retrain=true` ended in an `UndefVarError`. Regenerating the buffer is the reviewer's
+reading of `retrain`; confirm that this is the intended behaviour.
+"""
+function Base.rand(sampler::EnergySampler, n::Int=100; retrain=false, niter::Int=100)
+    if retrain
+        sampler.buffer = _generate_buffer(
+            sampler.ce, sampler.sampler, sampler.opt, size(sampler.buffer, 2); niter=niter
+        )
+    end
+    ntotal = size(sampler.buffer, 2)
+    idx = rand(1:ntotal, n)
+    return sampler.buffer[:, idx]
+end
-- 
GitLab