diff --git a/Manifest.toml b/Manifest.toml index 06b80abd0a09ec27e3d9f94fbf4dc29493113df6..9434115488ec05f1ed00e2d6098137a8214bf046 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -785,7 +785,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" version = "1.12.0" [[deps.JointEnergyModels]] -deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] +deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] path = "../JointEnergyModels.jl" uuid = "48c56d24-211d-4463-bbc0-7a701b291131" version = "0.1.0" diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml index 567a5edf16ff6947804306d6569a8706f2fd81f7..4c077ed9a0fb768eb3181835b8c94abd02cd8374 100644 --- a/notebooks/Manifest.toml +++ b/notebooks/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "b4b125f21013c1ac841e2bb761cd8922630f9f03" +project_hash = "53181f00a1b7318c04c79643b4b4b157e79df8f9" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] diff --git a/notebooks/Project.toml b/notebooks/Project.toml index 9d63873a5c251704aed1ab60c1256b1f9c4c8f71..180fcaf4d57cabfdde9e9f037538959c2d3262aa 100644 --- a/notebooks/Project.toml +++ b/notebooks/Project.toml @@ -1,6 +1,5 @@ [deps] AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67" -ECCCo = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e" @@ -8,7 +7,9 @@ Chain = "8be319e6-bccf-4806-a6f7-6fae938471bc" ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +ECCCo = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131" diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 0939ebd0ade4958b1a52644ae795d66309a36c21..b4fb8b774215fbe9c035a9d9778bad5ae7ded2a4 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -32,7 +32,7 @@ First, let's create a couple of image classifier architectures: # Model parameters: epochs = 100 batch_size = minimum([Int(round(n_obs/10)), 128]) -n_hidden = 50 +n_hidden = 128 activation = Flux.swish builder = MLJFlux.@builder Flux.Chain( @@ -71,7 +71,7 @@ mlp = NeuralNetworkClassifier( sampler = ConditionalSampler( ð’Ÿx, ð’Ÿy, input_size=(input_dim,), - batch_size=10 + batch_size=1 ) jem = JointEnergyClassifier( sampler; @@ -131,24 +131,15 @@ println("F1 score (test): $(round(f1,digits=3))") ``` ```{julia} -ð’Ÿx = Uniform(-1,1) -ð’Ÿy = Categorical(ones(K) ./ K) -sampler = UnconditionalSampler(ð’Ÿx; input_size=(D,)) -conditional_sampler = ConditionalSampler(ð’Ÿx, ð’Ÿy; input_size=(D,)) -opt = ImproperSGLD() -n_iter = 100 -``` - -```{julia} -# Random.seed!(1234) +Random.seed!(1234) # Set up search: -factual_label = 8 +factual_label = 9 x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) -target = 3 +target = 4 factual = predict_label(M, counterfactual_data, x)[1] γ = 0.5 -T = 100 +T = 250 η=1.0 @@ -168,8 +159,8 @@ ce_jsma = generate_counterfactual( ) # ECCCo: -λ=[0.0,1.0,10.0] -temp=0.01 +λ=[0.1,0.1,0.1] +temp=0.1 # Generate counterfactual using ECCCo generator: generator = ECCCoGenerator( diff --git a/src/penalties.jl b/src/penalties.jl index e9af36f125a7ddc0df9005d7b35bd70e9eb9d4c9..ccd1c8a911e1136f82422b8ee4ca200b58b47b8b 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -1,4 +1,5 @@ using ChainRules: ignore_derivatives +using Distances using LinearAlgebra: norm using Statistics: mean @@ -50,7 +51,8 @@ function distance_from_energy( x′ = CounterfactualExplanations.counterfactual(ce) loss = map(eachslice(x′, dims=ndims(x′))) do x Δ = map(eachcol(conditional_samples[1])) do xsample - norm(x - xsample) + # 1 .- (x'xsample)/(norm(x)*norm(xsample)) + norm(x - xsample, 1) end return mean(Δ) end diff --git a/www/eccco_mnist.png b/www/eccco_mnist.png index 8c8f3a1869c09a98aafb4cfdb94a5b4f09bf041c..edd2100307beb05f2b55511561bad5a4df0ad9f2 100644 Binary files a/www/eccco_mnist.png and b/www/eccco_mnist.png differ