Skip to content
Snippets Groups Projects
Commit 9f225a48 authored by Pat Alt's avatar Pat Alt
Browse files

MNIST

parent d953561f
Branches 38-laplace
No related tags found
No related merge requests found
Showing
with 130 additions and 17 deletions
artifacts/results/images/mnist_generated_JEM Ensemble.png

159 KiB

artifacts/results/images/mnist_generated_JEM.png

321 KiB

artifacts/results/images/mnist_generated_MLP Ensemble.png

335 KiB

artifacts/results/images/mnist_generated_MLP.png

325 KiB

acc,precision,f1score,mod_name
0.91,0.9103128613256043,0.9085266898485427,JEM Ensemble
0.9423,0.9418035808913033,0.9415726345973308,MLP
0.9439,0.943391966967219,0.9432116151207373,MLP Ensemble
0.8748,0.8798390712341672,0.8728379219312089,JEM
File added
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -138,7 +138,7 @@ end
```{julia}
# Data:
n_obs = 1000
n_obs = 10000
counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = pre_process.(counterfactual_data.X)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
......@@ -242,13 +242,11 @@ Serialization.serialize(joinpath(output_path,"mnist_models.jls"), model_dict)
```
```{julia}
for (mod_name, mod) in model_dict
# Plot generated samples:
# Plot:
if mod.model.model isa JointEnergyClassifier
sampler = mod.model.model.jem.sampler
elseif mod.model.model.model isa JointEnergyClassifier
sampler = mod.model.model.model.jem.sampler
for (mod_name, mod) in model_dict
if ECCCo._has_sampler(mod)
sampler = ECCCo._get_sampler(mod)
else
K = length(counterfactual_data.y_levels)
input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
......@@ -275,14 +273,90 @@ for (mod_name, mod) in model_dict
plts = [plts..., plt]
end
plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name)
savefig(plt, joinpath(output_images_path, "mnist_generated_$(mod_name).png"))
display(plt)
end
```
```{julia}
# Evaluate models:
measure = Dict(
:f1score => multiclass_f1score,
:acc => accuracy,
:precision => multiclass_precision
)
model_performance = DataFrame()
for (mod_name, mod) in model_dict
# Test performance:
test_data = load_mnist_test()
test_data.X = pre_process.(test_data.X)
f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
println("F1 score (test): $(round(f1,digits=3))")
_perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path,"mnist_model_performance.jls"), model_performance)
CSV.write(joinpath(output_path, "mnist_model_performance.csv"), model_performance)
model_performance
```
```{julia}
Random.seed!(123)
# Set up search:
factual = 8
x = reshape(counterfactual_data.X[:,rand(findall(labels.==factual))],input_dim,1)
target = 3
γ = 0.9
T = 100
# ECCCo:
λ=[0.5,0.1,0.5]
temp=0.1
η=0.01
# Generate counterfactuals using ECCCo generator:
generator = ECCCoGenerator(
λ=λ,
temp=temp,
opt=Flux.Optimise.Adam(η),
)
ces = Dict()
for (mod_name, mod) in model_dict
ce = generate_counterfactual(
x, target, counterfactual_data, mod, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
ces[mod_name] = ce
end
# Plot:
p1 = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
axis=nothing,
size=(img_height, img_height),
title="Factual"
)
plts = [p1]
for (_name,ce) in ces
x = CounterfactualExplanations.counterfactual(ce)
_phat = target_probs(ce)
_title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
plt = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
axis=nothing,
size=(img_height, img_height),
title=_title
)
plts = [plts..., plt]
end
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
```
```{julia}
......
......@@ -39,6 +39,7 @@ setup_notebooks = quote
Random.seed!(2023)
www_path = "www"
output_path = "artifacts/results"
output_images_path = "artifacts/results/images"
img_height = 300
end;
\ No newline at end of file
No preview for this file type
......@@ -266,7 +266,8 @@ As noted by \citet{guidotti2022counterfactual}, these distance-based measures ar
\item BatchNorm does not seem compatible with JEM
\item Coverage and temperature impacts CCE in somewhat unpredictable ways
\item It seems that models that are not explicitly trained for generative task, still learn it implictly
\item Batch size seems to impact quality of generated samples
\item Batch size seems to impact quality of generated samples (at inference, but not so much during JEM training)
\item ECCCo is sensitive to optimizer (Adam works well), learning rate and distance metric (l1 works well)
\end{itemize}
\section{Discussion}
......
......@@ -28,7 +28,7 @@ function ECCCoGenerator(;
function _set_size_penalty(ce::AbstractCounterfactualExplanation)
return ECCCo.set_size_penalty(ce; κ=κ, temp=temp)
end
_penalties = [Objectives.distance_l2, _set_size_penalty, ECCCo.distance_from_energy]
_penalties = [Objectives.distance_l1, _set_size_penalty, ECCCo.distance_from_energy]
λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ
return Generator(; penalty=_penalties, λ=λ, opt=opt, kwargs...)
end
......
......@@ -60,6 +60,33 @@ function _outdim(fitresult)
return outdim
end
"""
_get_sampler(model::ConformalModel)
Private helper function that extracts the sampler from a fitted model.
"""
function _get_sampler(model::ConformalModel)
_mod = model.model
if _mod.model isa MLJEnsembles.EitherEnsembleModel
_mod = _mod.model
end
if _mod.model isa JointEnergyClassifier
sampler = _mod.model.sampler
else
sampler = false
end
return sampler
end
"""
_has_sampler(model::ConformalModel)
Private helper function that checks if a fitted model has a sampler.
"""
function _has_sampler(model::ConformalModel)
return !(_get_sampler(model) isa Bool)
end
"""
ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symbol}=nothing)
......
......@@ -51,7 +51,6 @@ function distance_from_energy(
x′ = CounterfactualExplanations.counterfactual(ce)
loss = map(eachslice(x′, dims=ndims(x′))) do x
Δ = map(eachcol(conditional_samples[1])) do xsample
# 1 .- (x'xsample)/(norm(x)*norm(xsample))
norm(x - xsample, 1)
end
return mean(Δ)
......
......@@ -43,8 +43,8 @@ function EnergySampler(
@assert y data.y_levels || y 1:length(data.y_levels)
if model.model.model isa JointEnergyClassifier
sampler = model.model.model.jem.sampler
if ECCCo._has_sampler(model)
sampler = ECCCo._get_sampler(model)
else
K = length(data.y_levels)
input_size = size(selectdim(data.X, ndims(data.X), 1))
......@@ -57,8 +57,10 @@ function EnergySampler(
# Initiate:
energy_sampler = EnergySampler(model, data, sampler, opt, nothing, nothing)
# Generate conditional samples:
generate_samples!(energy_sampler, nsamples, yidx; niter=niter)
# Generate conditional samples (one at a time):
for i in 1:nsamples
generate_samples!(energy_sampler, 1, yidx; niter=niter)
end
return energy_sampler
end
......@@ -106,7 +108,11 @@ end
Generates `n` samples from `EnergySampler` for conditioning value `y`. Assigns samples and conditioning value to `EnergySampler`.
"""
function generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100)
e.buffer = generate_samples(e, n, y; niter=niter)
if isnothing(e.buffer)
e.buffer = generate_samples(e, n, y; niter=niter)
else
e.buffer = cat(e.buffer, generate_samples(e, n, y; niter=niter), dims=ndims(e.buffer))
end
e.yidx = y
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment