diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml index 400facb7be5d41186be3476bfe291d7420d08ee1..567a5edf16ff6947804306d6569a8706f2fd81f7 100644 --- a/notebooks/Manifest.toml +++ b/notebooks/Manifest.toml @@ -17,9 +17,9 @@ version = "0.4.4" [[deps.Accessors]] deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "StaticArrays", "Test"] -git-tree-sha1 = "beabc31fa319f9de4d16372bff31b4801e43d32c" +git-tree-sha1 = "c7dddee3f32ceac12abd9a21cd0c4cb489f230d2" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.28" +version = "0.1.29" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] @@ -1052,7 +1052,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" version = "1.12.0" [[deps.JointEnergyModels]] -deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] +deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] path = "../../JointEnergyModels.jl" uuid = "48c56d24-211d-4463-bbc0-7a701b291131" version = "0.1.0" diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 8a7f179f1dc24a1e55975659fc4c2a5b45d6ffa6..0939ebd0ade4958b1a52644ae795d66309a36c21 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -94,7 +94,7 @@ mlp_ens = EnsembleModel(model=mlp, n=50) ```{julia} cov = .95 -conf_model = conformal_model(jem; method=:adaptive_inductive, coverage=cov) +conf_model = conformal_model(jem; method=:simple_inductive, coverage=cov) mach = machine(conf_model, X, labels) fit!(mach) M = ECCCo.ConformalModel(mach.model, mach.fitresult) @@ -130,6 +130,15 @@ f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data) println("F1 score (test): $(round(f1,digits=3))") ``` +```{julia} +ð’Ÿx = Uniform(-1,1) +ð’Ÿy = Categorical(ones(K) ./ K) +sampler = UnconditionalSampler(ð’Ÿx; input_size=(D,)) +conditional_sampler = ConditionalSampler(ð’Ÿx, ð’Ÿy; input_size=(D,)) +opt = ImproperSGLD() +n_iter = 100 +``` + ```{julia} # Random.seed!(1234) @@ -141,15 +150,17 @@ factual = predict_label(M, counterfactual_data, x)[1] γ = 0.5 T = 100 +η=1.0 + # Generate counterfactual using generic generator: -generator = GenericGenerator(opt=Flux.Optimise.Adam(),) +generator = GenericGenerator() ce_wachter = generate_counterfactual( x, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, ) -generator = GreedyGenerator(η=1.0) +generator = GreedyGenerator(η=η) ce_jsma = generate_counterfactual( x, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, @@ -157,14 +168,13 @@ ce_jsma = generate_counterfactual( ) # ECCCo: -λ=[0.0,1.0] -temp=0.5 +λ=[0.0,1.0,10.0] +temp=0.01 -# Generate counterfactual using CCE generator: -generator = CCEGenerator( +# Generate counterfactual using ECCCo generator: +generator = ECCCoGenerator( λ=λ, temp=temp, - opt=Flux.Optimise.Adam(), ) ce_conformal = generate_counterfactual( x, target, counterfactual_data, M, generator; @@ -173,11 +183,11 @@ ce_conformal = generate_counterfactual( converge_when=:generator_conditions, ) -# Generate counterfactual using CCE generator: -generator = CCEGenerator( +# Generate counterfactual using ECCCo generator: +generator = ECCCoGenerator( λ=λ, temp=temp, - opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0), + opt=CounterfactualExplanations.Generators.JSMADescent(η=η), ) ce_conformal_jsma = generate_counterfactual( x, target, counterfactual_data, M, generator; @@ -196,7 +206,7 @@ p1 = Plots.plot( plts = [p1] ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma] -_names = ["Wachter", "CCE", "JSMA", "CCE-JSMA"] +_names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"] for x in zip(ces, _names) ce, _name = (x[1],x[2]) x = CounterfactualExplanations.counterfactual(ce) @@ -212,7 +222,7 @@ for x in zip(ces, _names) end plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) display(plt) -savefig(plt, joinpath(www_path, "cce_mnist.png")) +savefig(plt, joinpath(www_path, "eccco_mnist.png")) ``` ```{julia} @@ -226,8 +236,6 @@ factual = predict_label(M, counterfactual_data, x)[1] γ = 0.5 T = 100 -η=0.1 - # Generate counterfactual using generic generator: generator = GenericGenerator(opt=Flux.Optimise.Adam(),) ce_wachter = generate_counterfactual( @@ -236,7 +244,7 @@ ce_wachter = generate_counterfactual( initialization=:identity, ) -generator = GreedyGenerator(η=η) +generator = GreedyGenerator(η=1.0) ce_jsma = generate_counterfactual( x, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, @@ -244,11 +252,11 @@ ce_jsma = generate_counterfactual( ) # ECCCo: -λ=[0.0,0.0,10.0] +λ=[0.0,1.0] temp=0.5 -# Generate counterfactual using ECCCo generator: -generator = ECCCoGenerator( +# Generate counterfactual using CCE generator: +generator = CCEGenerator( λ=λ, temp=temp, opt=Flux.Optimise.Adam(), @@ -260,11 +268,11 @@ ce_conformal = generate_counterfactual( converge_when=:generator_conditions, ) -# Generate counterfactual using ECCCo generator: -generator = ECCCoGenerator( +# Generate counterfactual using CCE generator: +generator = CCEGenerator( λ=λ, temp=temp, - opt=CounterfactualExplanations.Generators.JSMADescent(η=η), + opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0), ) ce_conformal_jsma = generate_counterfactual( x, target, counterfactual_data, M, generator; @@ -283,7 +291,7 @@ p1 = Plots.plot( plts = [p1] ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma] -_names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"] +_names = ["Wachter", "CCE", "JSMA", "CCE-JSMA"] for x in zip(ces, _names) ce, _name = (x[1],x[2]) x = CounterfactualExplanations.counterfactual(ce) @@ -299,7 +307,7 @@ for x in zip(ces, _names) end plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) display(plt) -savefig(plt, joinpath(www_path, "eccco_mnist.png")) +savefig(plt, joinpath(www_path, "cce_mnist.png")) ``` ```{julia} diff --git a/paper/paper.pdf b/paper/paper.pdf index 222cec3cdb4a1bfb1966a554996d478b458fc381..eddb381562b5e51c254fb4629f64e5cc0707647a 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/src/penalties.jl b/src/penalties.jl index 95c96b4109cba3bf5f40d4c8f7ae03e44bfefceb..e9af36f125a7ddc0df9005d7b35bd70e9eb9d4c9 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -36,13 +36,13 @@ end function distance_from_energy( ce::AbstractCounterfactualExplanation; - n::Int=1, niter=200, from_buffer=true, agg=mean, kwargs... + n::Int=10, niter=200, from_buffer=true, agg=mean, kwargs... ) conditional_samples = [] ignore_derivatives() do _dict = ce.params if !(:energy_sampler ∈ collect(keys(_dict))) - _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=1000, kwargs...) + _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) end sampler = _dict[:energy_sampler] push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer)) diff --git a/src/sampling.jl b/src/sampling.jl index 50a170283a65f6e0ad64faae58084913898f0a95..023c3d87fd7a138fc21431f4fd603aa19ace771e 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -57,10 +57,8 @@ function EnergySampler( # Initiate: energy_sampler = EnergySampler(model, data, sampler, opt, nothing, nothing) - # Generate samples: - chain = model.model.model.jem.chain - rule = model.model.model.jem.sampling_rule - energy_sampler.sampler(chain, rule; niter=niter, n_samples=nsamples, y=yidx) + # Generate conditional samples: + generate_samples!(energy_sampler, nsamples, yidx; niter=niter) return energy_sampler end @@ -86,19 +84,43 @@ function EnergySampler( return EnergySampler(model, data, y; kwrgs...) end +""" + generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100) + +Generates `n` samples from `EnergySampler` for conditioning value `y`. +""" +function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100) + + # Generate samples: + chain = e.model.fitresult[1] + rule = e.opt + xsamples = e.sampler(chain, rule; niter=niter, n_samples=n, y=y) + + return xsamples +end + +""" + generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100) + +Generates `n` samples from `EnergySampler` for conditioning value `y`. Assigns samples and conditioning value to `EnergySampler`. +""" +function generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100) + e.buffer = generate_samples(e, n, y; niter=niter) + e.yidx = y +end + """ Base.rand(sampler::EnergySampler, n::Int=100; retrain=false) Overloads the `rand` method to randomly draw `n` samples from `EnergySampler`. """ function Base.rand(sampler::EnergySampler, n::Int=100; from_buffer=true, niter::Int=100) - ntotal = size(sampler.sampler.buffer)[end] + ntotal = size(sampler.buffer, 2) idx = rand(1:ntotal, n) if from_buffer - X = sampler.sampler.buffer[:, idx] + X = sampler.buffer[:, idx] else - chain = sampler.model.fitresult[1] - X = sampler.sampler(chain, sampler.opt; niter=niter, n_samples=n, y=sampler.yidx) + X = generate_samples(sampler, n, sampler.yidx; niter=niter) end return X end diff --git a/www/eccco_mnist.png b/www/eccco_mnist.png index 81cecbe4d0f947bf55ef08cd7f2db35f318337c2..8c8f3a1869c09a98aafb4cfdb94a5b4f09bf041c 100644 Binary files a/www/eccco_mnist.png and b/www/eccco_mnist.png differ