Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
E
ECCCo-jl
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Patrick Altmeyer
ECCCo-jl
Commits
8abeacbf
Commit
8abeacbf
authored
1 year ago
by
pat-alt
Browse files
Options
Downloads
Patches
Plain Diff
added code for real-world data experiments
parent
5ea52468
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
artifacts/results/cal_housing_models.jls
+0
-0
0 additions, 0 deletions
artifacts/results/cal_housing_models.jls
notebooks/cal_housing.qmd
+202
-0
202 additions, 0 deletions
notebooks/cal_housing.qmd
notebooks/gmsc.qmd
+202
-0
202 additions, 0 deletions
notebooks/gmsc.qmd
with
404 additions
and
0 deletions
artifacts/results/cal_housing_models.jls
0 → 100644
+
0
−
0
View file @
8abeacbf
File added
This diff is collapsed.
Click to expand it.
notebooks/
re
al_
world
.qmd
→
notebooks/
c
al_
housing
.qmd
+
202
−
0
View file @
8abeacbf
...
@@ -7,18 +7,15 @@ eval(setup_notebooks)
...
@@ -7,18 +7,15 @@ eval(setup_notebooks)
```{julia}
```{julia}
# Hyper:
# Hyper:
_retrain = false
_retrain = true
_regen = false
# Data:
# Data:
n_obs = 10000
n_obs = 10000
datasets = load_tabular_data(n_obs; drop=:credit_default
)
counterfactual_data = load_california_housing(n_obs
)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
X = table(permutedims(X))
x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1)
labels = counterfactual_data.output_encoder.labels
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
input_dim, n_obs = size(counterfactual_data.X)
n_digits = Int(sqrt(input_dim))
output_dim = length(unique(labels))
output_dim = length(unique(labels))
```
```
...
@@ -41,7 +38,7 @@ _finaliser = x -> x # finaliser function
...
@@ -41,7 +38,7 @@ _finaliser = x -> x # finaliser function
```{julia}
```{julia}
# JEM parameters:
# JEM parameters:
𝒟x =
Uniform(0,1
)
𝒟x =
Normal(
)
𝒟y = Categorical(ones(output_dim) ./ output_dim)
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(
sampler = ConditionalSampler(
𝒟x, 𝒟y,
𝒟x, 𝒟y,
...
@@ -112,260 +109,45 @@ function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_nam
...
@@ -112,260 +109,45 @@ function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_nam
end
end
if _retrain
if _retrain
model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models)
model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models)
Serialization.serialize(joinpath(output_path,"
mnist
_models.jls"), model_dict)
Serialization.serialize(joinpath(output_path,"
cal_housing
_models.jls"), model_dict)
else
else
model_dict = Serialization.deserialize(joinpath(output_path,"
mnist
_models.jls"))
model_dict = Serialization.deserialize(joinpath(output_path,"
cal_housing
_models.jls"))
end
end
```
```
```{julia}
```{julia}
# Plot generated samples:
# # Evaluate models:
n_regen = 150
if _regen
# measure = Dict(
for (mod_name, mod) in model_dict
# :f1score => multiclass_f1score,
if ECCCo._has_sampler(mod)
# :acc => accuracy,
sampler = ECCCo._get_sampler(mod)
# :precision => multiclass_precision
else
# )
K = length(counterfactual_data.y_levels)
# model_performance = DataFrame()
input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
# for (mod_name, mod) in model_dict
𝒟x = Uniform(extrema(counterfactual_data.X)...)
# # Test performance:
𝒟y = Categorical(ones(K) ./ K)
# test_data = load_mnist_test()
sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size)
# _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
end
# _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
opt = ImproperSGLD()
# _perf.mod_name .= mod_name
f(x) = logits(mod, x)
# model_performance = vcat(model_performance, _perf)
# end
_w = 1500
# Serialization.serialize(joinpath(output_path,"cal_housing_model_performance.jls"), model_performance)
plts = []
# CSV.write(joinpath(output_path, "cal_housing_model_performance.csv"), model_performance)
neach = 10
# model_performance
for i in 1:10
x = sampler(f, opt; niter=n_regen, n_samples=neach, y=i)
plts_i = []
for j in 1:size(x, 2)
xj = x[:,j]
xj = reshape(xj, (n_digits, n_digits))
plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
end
plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
plts = [plts..., plt]
end
plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name)
savefig(plt, joinpath(output_images_path, "mnist_generated_$(mod_name).png"))
display(plt)
end
end
```
```
```{julia}
## Benchmark
# Evaluate models:
measure = Dict(
:f1score => multiclass_f1score,
:acc => accuracy,
:precision => multiclass_precision
)
model_performance = DataFrame()
for (mod_name, mod) in model_dict
# Test performance:
test_data = load_mnist_test()
_perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path,"mnist_model_performance.jls"), model_performance)
CSV.write(joinpath(output_path, "mnist_model_performance.csv"), model_performance)
model_performance
```
### Different Models
```{julia}
"""
    _plot_eccco_mnist(x=x_factual, target=target; λ, temp, η, plt_order, opt, rng)

Generate one ECCCo counterfactual per model in `model_dict` and plot them next to the
factual.

# Arguments
- `x`: the factual as an array, or an `Int` label, in which case a random sample with
  that label is drawn from `counterfactual_data`.
- `target`: the target class.

# Keywords
- `λ`, `temp`, `opt`: forwarded to `ECCCoGenerator`.
- `η`: step size of the default `opt`.
- `plt_order`: model names in the order in which their panels appear; every name must
  be a key of `model_dict`.
- `rng`: passed to `Random.seed!` before the factual is drawn.

# Returns
The tuple `(plt, eccco_generator)`.
"""
function _plot_eccco_mnist(
    x::Union{AbstractArray, Int}=x_factual, target::Int=target;
    λ=[0.5,0.1,0.5],
    temp=0.1,η=0.01,
    plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"],
    opt = Flux.Optimise.Adam(η),
    rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG,
)

    # Setup:
    Random.seed!(rng)
    if x isa Int
        x = reshape(counterfactual_data.X[:,rand(findall(labels.==x))],input_dim,1)
    end

    # Generate counterfactuals using ECCCo generator:
    eccco_generator = ECCCoGenerator(
        λ=λ,
        temp=temp,
        opt=opt,
    )
    ces = Dict()
    for (mod_name, mod) in model_dict
        ces[mod_name] = generate_counterfactual(
            x, target, counterfactual_data, mod, eccco_generator;
            decision_threshold=γ, max_iter=T,
            initialization=:identity,
            converge_when=:generator_conditions,
        )
    end

    # Plot:
    p1 = Plots.plot(
        convert2image(MNIST, reshape(x,28,28)),
        axis=nothing,
        size=(img_height, img_height),
        title="Factual"
    )

    # Build the panels directly in `plt_order`, looking each counterfactual up by model
    # name. This does not depend on `ces` and `model_dict` iterating in the same order.
    plts = map(plt_order) do _name
        ce = ces[_name]
        _x = CounterfactualExplanations.counterfactual(ce)
        _phat = target_probs(ce)
        _title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
        return Plots.plot(
            convert2image(MNIST, reshape(_x,28,28)),
            axis=nothing,
            size=(img_height, img_height),
            title=_title
        )
    end

    plts = [p1, plts...]
    plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))

    return plt, eccco_generator
end
```
```{julia}
# Build the comparison plot with the default arguments, keep the generator it returns,
# show the figure and write it to `output_images_path`.
plt, eccco_generator = _plot_eccco_mnist()
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
```
### All digits
```{julia}
"""
    plot_mnist(factual, target; generator, model, data, rng, _plot_title, kwargs...)

Draw a random sample from `data` that `model` predicts as `factual`, generate a
counterfactual for `target` with `generator`, and return a plot of the counterfactual.

# Keywords
- `generator`: the counterfactual generator (required).
- `model`: fitted model to explain; defaults to `model_dict["JEM Ensemble"]`.
- `data`: counterfactual data; defaults to `counterfactual_data`.
- `rng`: random number generator, or an integer seed, used to draw the factual.
- `_plot_title`: if `true`, the plot is titled `"factual -> target"`.
- `kwargs...`: forwarded to `generate_counterfactual`. When not supplied,
  `decision_threshold=0.9`, `max_iter=100`, `initialization=:identity` and
  `converge_when=:generator_conditions` are used.

# Throws
- `ArgumentError` if no sample in `data` is predicted as `factual`.
"""
function plot_mnist(
    factual::Int,target::Int;
    generator::AbstractGenerator,
    model::AbstractFittedModel=model_dict["JEM Ensemble"],
    data::CounterfactualData=counterfactual_data,
    rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG,
    _plot_title::Bool=true,
    kwargs...,
)

    # Defaults first, caller-supplied keywords second, so the caller's values win.
    # (`isdefined(kwargs, :name)` inspects struct fields, not keyword names, so it
    # cannot be used to test whether a keyword was passed.)
    ce_kwargs = merge(
        (
            decision_threshold=0.9,
            max_iter=100,
            initialization=:identity,
            converge_when=:generator_conditions,
        ),
        values(kwargs),
    )

    # Honour `rng` when choosing the factual; an integer is treated as a seed.
    _rng = rng isa Int ? Random.MersenneTwister(rng) : rng
    candidates = findall(predict_label(model, data).==factual)
    isempty(candidates) &&
        throw(ArgumentError("no sample in `data` is predicted as $(factual)"))
    x = reshape(data.X[:,rand(_rng, candidates)],input_dim,1)

    ce = generate_counterfactual(x, target, data, model, generator; ce_kwargs...)

    _title = _plot_title ? "$(factual) -> $(target)" : ""
    _x = CounterfactualExplanations.counterfactual(ce)
    plt = Plots.plot(
        convert2image(MNIST, reshape(_x,28,28)),
        axis=nothing,
        size=(img_height, img_height),
        title=_title
    )
    return plt
end
```
```{julia}
if _regen
    # Generate a counterfactual for every (factual, target) pair of digits 0-9 via
    # `plot_mnist` and arrange the resulting panels in a 10x10 grid. Extra keywords
    # are forwarded to `plot_mnist`; with `verbose=true` each panel is displayed as
    # soon as it has been generated.
    function plot_all_digits(rng=1;verbose=true,kwargs...)
        panels = []
        for i in 0:9, j in 0:9
            @info "Generating counterfactual for $(i) -> $(j)"
            panel = plot_mnist(i,j;kwargs...,rng=rng)
            if verbose
                display(panel)
            end
            push!(panels, panel)
        end
        return Plots.plot(panels...; size=(img_height*10,img_height*10), layout=(10,10))
    end

    plt = plot_all_digits(generator=eccco_generator)
    savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits.png"))
end
```
### Different Generators
```{julia}
```{julia}
# Setup:
model = model_dict["JEM Ensemble"]
# Benchmark generators:
# Benchmark generators:
generator_dict = Dict(
generator_dict = Dict(
:wachter => generic_generator,
:wachter => WachterGenerator(),
:revise => revise_generator,
:revise => REVISEGenerator(),
:greedy => greedy_generator,
:greedy => GreedyGenerator(),
:eccco => eccco_generator,
:eccco => ECCCoGenerator(),
)
ces = Dict()
for (gen_name, gen) in generator_dict
ce = generate_counterfactual(
x_factual, target, counterfactual_data, model, gen;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
ces[gen_name] = ce
end
plt_order = sortperm(collect(keys(ces)))
# Plot:
p1 = Plots.plot(
convert2image(MNIST, reshape(x_factual,28,28)),
axis=nothing,
size=(img_height, img_height),
title="Factual"
)
)
plts = []
for (_name,ce) in ces
_x = CounterfactualExplanations.counterfactual(ce)
_phat = target_probs(ce)
_title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
plt = Plots.plot(
convert2image(MNIST, reshape(_x,28,28)),
axis=nothing,
size=(img_height, img_height),
title=_title
)
plts = [plts..., plt]
end
plts = plts[plt_order]
plts = [p1, plts...]
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_all_generators.png"))
```
## Benchmark
```{julia}
# Measures:
# Measures:
measures = [
measures = [
CounterfactualExplanations.distance,
CounterfactualExplanations.distance,
...
@@ -380,12 +162,12 @@ bmk = benchmark(
...
@@ -380,12 +162,12 @@ bmk = benchmark(
models=model_dict,
models=model_dict,
generators=generator_dict,
generators=generator_dict,
measure=measures,
measure=measures,
suppress_training=true, dataname="
MNIST
",
suppress_training=true, dataname="
California Housing
",
n_individuals=5,
n_individuals=5,
factual=0, target=1,
factual=0, target=1,
initialization=:identity,
initialization=:identity,
)
)
CSV.write(joinpath(output_path, "
mnist
_benchmark.csv"), bmk())
CSV.write(joinpath(output_path, "
cal_housing
_benchmark.csv"), bmk())
```
```
...
@@ -416,5 +198,5 @@ plt = draw(
...
@@ -416,5 +198,5 @@ plt = draw(
facet=(; linkyaxes=:minimal)
facet=(; linkyaxes=:minimal)
)
)
display(plt)
display(plt)
save(joinpath(output_images_path, "
mnist
_benchmark.png"), plt, px_per_unit=5)
save(joinpath(output_images_path, "
cal_housing
_benchmark.png"), plt, px_per_unit=5)
```
```
\ No newline at end of file
This diff is collapsed.
Click to expand it.
notebooks/gmsc.qmd
0 → 100644
+
202
−
0
View file @
8abeacbf
```{julia}
include("notebooks/setup.jl")
eval(setup_notebooks)
```
# Real-World Data
```{julia}
# Hyper:
_retrain = true  # train and serialize the models below instead of loading them from disk

# Data:
n_obs = 10000
# This is the GMSC notebook (all outputs are written as `gmsc_*`), so load the GMSC
# data rather than the California Housing data used by the sibling notebook.
counterfactual_data = load_gmsc(n_obs)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))  # transpose, then wrap as a table
labels = counterfactual_data.output_encoder.labels
# `n_obs` is overwritten here with the number of observations actually loaded.
input_dim, n_obs = size(counterfactual_data.X)
output_dim = length(unique(labels))
```
First, let's define a couple of classifier architectures for the tabular data:
```{julia}
# Model parameters:
epochs = 100
n_hidden = 64
n_ens = 5                                # number of models in ensemble
activation = Flux.relu
_loss = Flux.Losses.logitcrossentropy    # loss function
_finaliser = x -> x                      # finaliser function

# Batch size: a tenth of the observations, capped at 128.
batch_size = min(round(Int, n_obs / 10), 128)

# Network with a single hidden layer:
builder = MLJFlux.@builder Flux.Chain(
    Dense(n_in, n_hidden, activation),
    Dense(n_hidden, n_out),
)
```
```{julia}
# JEM parameters:
α = [1.0, 1.0, 0.01]  # penalty strengths

# Distributions handed to the sampler: a standard normal for the inputs and a uniform
# categorical over the `output_dim` classes for the labels.
𝒟x = Normal()
𝒟y = Categorical(fill(1 / output_dim, output_dim))
sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=(input_dim,), batch_size=10)
```
```{julia}
# Simple MLP:
mlp = NeuralNetworkClassifier(;
    builder=builder, epochs=epochs, batch_size=batch_size,
    finaliser=_finaliser, loss=_loss,
)

# Deep Ensemble:
mlp_ens = EnsembleModel(; model=mlp, n=n_ens)

# Joint Energy Model (same network settings as the MLP, plus the sampler):
jem = JointEnergyClassifier(
    sampler;
    builder=builder, epochs=epochs, batch_size=batch_size,
    finaliser=_finaliser, loss=_loss,
    jem_training_params=(; α=α, verbosity=10),
    sampling_steps=20,
)

# JEM with adversarial training (currently a plain copy, since the flag is commented out):
jem_adv = deepcopy(jem)
# jem_adv.adv_training = true

# Deep Ensemble of Joint Energy Models:
jem_ens = EnsembleModel(; model=jem, n=n_ens)

# Deep Ensemble of Joint Energy Models with adversarial training:
# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens)

# Dictionary of models, keyed by display name:
models = Dict(
    "MLP" => mlp,
    "MLP Ensemble" => mlp_ens,
    "JEM" => jem,
    "JEM Ensemble" => jem_ens,
    # "JEM Ensemble+" => jem_ens_plus,
)
```
```{julia}
# Train models:
"""
    _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model")

Wrap `model` in a conformal model with coverage `cov` and the given `method`, fit it on
`X` and `y`, and return the result as an `ECCCo.ConformalModel`. `mod_name` is only used
in the log messages emitted before and after fitting.
"""
function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model")
    mach = machine(conformal_model(model; method=method, coverage=cov), X, y)
    @info "Begin training $mod_name."
    fit!(mach)
    @info "Finished training $mod_name."
    return ECCCo.ConformalModel(mach.model, mach.fitresult)
end
# Either load previously serialized models or train every model and store the result.
if !_retrain
    model_dict = Serialization.deserialize(joinpath(output_path,"gmsc_models.jls"))
else
    model_dict = Dict(name => _train(m; mod_name=name) for (name, m) in models)
    Serialization.serialize(joinpath(output_path,"gmsc_models.jls"), model_dict)
end
```
```{julia}
# NOTE(review): this disabled evaluation step still calls `load_mnist_test()`; it needs
# a test split of this notebook's data set before it can be re-enabled.
# # Evaluate models:
# measure = Dict(
# :f1score => multiclass_f1score,
# :acc => accuracy,
# :precision => multiclass_precision
# )
# model_performance = DataFrame()
# for (mod_name, mod) in model_dict
# # Test performance:
# test_data = load_mnist_test()
# _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
# _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
# _perf.mod_name .= mod_name
# model_performance = vcat(model_performance, _perf)
# end
# Serialization.serialize(joinpath(output_path,"gmsc_model_performance.jls"), model_performance)
# CSV.write(joinpath(output_path, "gmsc_model_performance.csv"), model_performance)
# model_performance
```
## Benchmark
```{julia}
# Benchmark generators (all with their default settings):
generator_dict = Dict(
    :wachter => WachterGenerator(),
    :revise => REVISEGenerator(),
    :greedy => GreedyGenerator(),
    :eccco => ECCCoGenerator(),
)

# Measures:
measures = [
    CounterfactualExplanations.distance,
    ECCCo.distance_from_energy,
    ECCCo.distance_from_targets,
    CounterfactualExplanations.Evaluation.validity,
    CounterfactualExplanations.Evaluation.redundancy,
]

# Run every generator against every trained model.
bmk = benchmark(
    counterfactual_data;
    models=model_dict,
    generators=generator_dict,
    measure=measures,
    # Label the results with this notebook's data set, consistent with the `gmsc_*`
    # output files (the label was copied over from the California Housing notebook).
    suppress_training=true, dataname="GMSC",
    n_individuals=5,
    factual=0, target=1,
    initialization=:identity,
)
CSV.write(joinpath(output_path, "gmsc_benchmark.csv"), bmk())
```
```{julia}
# Mean and standard deviation of each benchmark measure per data set, generator and
# model, showing only the rows for `distance_from_energy`.
@chain bmk() begin
@group_by(dataname, generator, model, variable)
@summarize(mean=mean(value),sd=std(value))
@ungroup
@filter(variable == "distance_from_energy")
end
```
```{julia}
# Keep the three distance-based measures and give them display names for the figure.
df = @chain bmk() begin
@filter(variable in [
"distance_from_energy",
"distance_from_targets",
"distance",])
@mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
@mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable))
@mutate(variable = ifelse.(variable .== "distance", "Cost", variable))
end
# Box plots of `value` by generator, with one row per measure and one column per model.
plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) *
mapping(:generator, :value, row=:variable, col=:model, color=:generator)
# `plt` is rebound from the plot specification to the drawn figure; x-axis labels and
# ticks are hidden.
plt = draw(
plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120),
facet=(; linkyaxes=:minimal)
)
display(plt)
save(joinpath(output_images_path, "gmsc_benchmark.png"), plt, px_per_unit=5)
```
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment