Skip to content
Snippets Groups Projects
Commit 8abeacbf authored by pat-alt's avatar pat-alt
Browse files

added code for real-world data experiments

parent 5ea52468
No related branches found
No related tags found
No related merge requests found
File added
...@@ -7,18 +7,15 @@ eval(setup_notebooks) ...@@ -7,18 +7,15 @@ eval(setup_notebooks)
```{julia} ```{julia}
# Hyper: # Hyper:
_retrain = false _retrain = true
_regen = false
# Data: # Data:
n_obs = 10000 n_obs = 10000
datasets = load_tabular_data(n_obs; drop=:credit_default) counterfactual_data = load_california_housing(n_obs)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X)) X = table(permutedims(X))
x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1)
labels = counterfactual_data.output_encoder.labels labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X) input_dim, n_obs = size(counterfactual_data.X)
n_digits = Int(sqrt(input_dim))
output_dim = length(unique(labels)) output_dim = length(unique(labels))
``` ```
...@@ -41,7 +38,7 @@ _finaliser = x -> x # finaliser function ...@@ -41,7 +38,7 @@ _finaliser = x -> x # finaliser function
```{julia} ```{julia}
# JEM parameters: # JEM parameters:
𝒟x = Uniform(0,1) 𝒟x = Normal()
𝒟y = Categorical(ones(output_dim) ./ output_dim) 𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler( sampler = ConditionalSampler(
𝒟x, 𝒟y, 𝒟x, 𝒟y,
...@@ -112,260 +109,45 @@ function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_nam ...@@ -112,260 +109,45 @@ function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_nam
end end
if _retrain if _retrain
model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models) model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models)
Serialization.serialize(joinpath(output_path,"mnist_models.jls"), model_dict) Serialization.serialize(joinpath(output_path,"cal_housing_models.jls"), model_dict)
else else
model_dict = Serialization.deserialize(joinpath(output_path,"mnist_models.jls")) model_dict = Serialization.deserialize(joinpath(output_path,"cal_housing_models.jls"))
end end
``` ```
```{julia} ```{julia}
# Plot generated samples: # # Evaluate models:
n_regen = 150
if _regen # measure = Dict(
for (mod_name, mod) in model_dict # :f1score => multiclass_f1score,
if ECCCo._has_sampler(mod) # :acc => accuracy,
sampler = ECCCo._get_sampler(mod) # :precision => multiclass_precision
else # )
K = length(counterfactual_data.y_levels) # model_performance = DataFrame()
input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1)) # for (mod_name, mod) in model_dict
𝒟x = Uniform(extrema(counterfactual_data.X)...) # # Test performance:
𝒟y = Categorical(ones(K) ./ K) # test_data = load_mnist_test()
sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size) # _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
end # _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
opt = ImproperSGLD() # _perf.mod_name .= mod_name
f(x) = logits(mod, x) # model_performance = vcat(model_performance, _perf)
# end
_w = 1500 # Serialization.serialize(joinpath(output_path,"cal_housing_model_performance.jls"), model_performance)
plts = [] # CSV.write(joinpath(output_path, "cal_housing_model_performance.csv"), model_performance)
neach = 10 # model_performance
for i in 1:10
x = sampler(f, opt; niter=n_regen, n_samples=neach, y=i)
plts_i = []
for j in 1:size(x, 2)
xj = x[:,j]
xj = reshape(xj, (n_digits, n_digits))
plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
end
plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
plts = [plts..., plt]
end
plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name)
savefig(plt, joinpath(output_images_path, "mnist_generated_$(mod_name).png"))
display(plt)
end
end
``` ```
```{julia} ## Benchmark
# Evaluate models:
measure = Dict(
:f1score => multiclass_f1score,
:acc => accuracy,
:precision => multiclass_precision
)
model_performance = DataFrame()
for (mod_name, mod) in model_dict
# Test performance:
test_data = load_mnist_test()
_perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path,"mnist_model_performance.jls"), model_performance)
CSV.write(joinpath(output_path, "mnist_model_performance.csv"), model_performance)
model_performance
```
### Different Models
```{julia}
function _plot_eccco_mnist(
x::Union{AbstractArray, Int}=x_factual, target::Int=target;
λ=[0.5,0.1,0.5],
temp=0.1,η=0.01,
plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"],
opt = Flux.Optimise.Adam(η),
rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG,
)
# Setup:
Random.seed!(rng)
if x isa Int
x = reshape(counterfactual_data.X[:,rand(findall(labels.==x))],input_dim,1)
end
# Generate counterfactuals using ECCCo generator:
eccco_generator = ECCCoGenerator(
λ=λ,
temp=temp,
opt=opt,
)
ces = Dict()
for (mod_name, mod) in model_dict
ce = generate_counterfactual(
x, target, counterfactual_data, mod, eccco_generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
ces[mod_name] = ce
end
_plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order)
# Plot:
p1 = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
axis=nothing,
size=(img_height, img_height),
title="Factual"
)
plts = []
for (_name,ce) in ces
_x = CounterfactualExplanations.counterfactual(ce)
_phat = target_probs(ce)
_title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
plt = Plots.plot(
convert2image(MNIST, reshape(_x,28,28)),
axis=nothing,
size=(img_height, img_height),
title=_title
)
plts = [plts..., plt]
end
plts = plts[_plt_order]
plts = [p1, plts...]
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
return plt, eccco_generator
end
```
```{julia}
plt, eccco_generator = _plot_eccco_mnist()
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
```
### All digits
```{julia}
function plot_mnist(
factual::Int,target::Int;
generator::AbstractGenerator,
model::AbstractFittedModel=model_dict["JEM Ensemble"],
data::CounterfactualData=counterfactual_data,
rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG,
_plot_title::Bool=true,
kwargs...,
)
decision_threshold = !isdefined(kwargs, :decision_threshold) ? 0.9 : decision_threshold
max_iter = !isdefined(kwargs, :max_iter) ? 100 : max_iter
initialization = !isdefined(kwargs, :initialization) ? :identity : initialization
converge_when = !isdefined(kwargs, :converge_when) ? :generator_conditions : converge_when
x = reshape(data.X[:,rand(findall(predict_label(model, data).==factual))],input_dim,1)
ce = generate_counterfactual(
x, target, data, model, generator;
decision_threshold=decision_threshold, max_iter=max_iter,
initialization=initialization,
converge_when=converge_when,
kwargs...
)
_title = _plot_title ? "$(factual) -> $(target)" : ""
_x = CounterfactualExplanations.counterfactual(ce)
plt = Plots.plot(
convert2image(MNIST, reshape(_x,28,28)),
axis=nothing,
size=(img_height, img_height),
title=_title
)
return plt
end
```
```{julia}
if _regen
function plot_all_digits(rng=1;verbose=true,kwargs...)
plts = []
for i in 0:9
for j in 0:9
@info "Generating counterfactual for $(i) -> $(j)"
plt = plot_mnist(i,j;kwargs...,rng=rng)
!verbose || display(plt)
plts = [plts..., plt]
end
end
plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10))
return plt
end
plt = plot_all_digits(generator=eccco_generator)
savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits.png"))
end
```
### Different Generators
```{julia} ```{julia}
# Setup:
model = model_dict["JEM Ensemble"]
# Benchmark generators: # Benchmark generators:
generator_dict = Dict( generator_dict = Dict(
:wachter => generic_generator, :wachter => WachterGenerator(),
:revise => revise_generator, :revise => REVISEGenerator(),
:greedy => greedy_generator, :greedy => GreedyGenerator(),
:eccco => eccco_generator, :eccco => ECCCoGenerator(),
)
ces = Dict()
for (gen_name, gen) in generator_dict
ce = generate_counterfactual(
x_factual, target, counterfactual_data, model, gen;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
ces[gen_name] = ce
end
plt_order = sortperm(collect(keys(ces)))
# Plot:
p1 = Plots.plot(
convert2image(MNIST, reshape(x_factual,28,28)),
axis=nothing,
size=(img_height, img_height),
title="Factual"
) )
plts = []
for (_name,ce) in ces
_x = CounterfactualExplanations.counterfactual(ce)
_phat = target_probs(ce)
_title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
plt = Plots.plot(
convert2image(MNIST, reshape(_x,28,28)),
axis=nothing,
size=(img_height, img_height),
title=_title
)
plts = [plts..., plt]
end
plts = plts[plt_order]
plts = [p1, plts...]
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_all_generators.png"))
```
## Benchmark
```{julia}
# Measures: # Measures:
measures = [ measures = [
CounterfactualExplanations.distance, CounterfactualExplanations.distance,
...@@ -380,12 +162,12 @@ bmk = benchmark( ...@@ -380,12 +162,12 @@ bmk = benchmark(
models=model_dict, models=model_dict,
generators=generator_dict, generators=generator_dict,
measure=measures, measure=measures,
suppress_training=true, dataname="MNIST", suppress_training=true, dataname="California Housing",
n_individuals=5, n_individuals=5,
factual=0, target=1, factual=0, target=1,
initialization=:identity, initialization=:identity,
) )
CSV.write(joinpath(output_path, "mnist_benchmark.csv"), bmk()) CSV.write(joinpath(output_path, "cal_housing_benchmark.csv"), bmk())
``` ```
...@@ -416,5 +198,5 @@ plt = draw( ...@@ -416,5 +198,5 @@ plt = draw(
facet=(; linkyaxes=:minimal) facet=(; linkyaxes=:minimal)
) )
display(plt) display(plt)
save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5) save(joinpath(output_images_path, "cal_housing_benchmark.png"), plt, px_per_unit=5)
``` ```
\ No newline at end of file
```{julia}
include("notebooks/setup.jl")
eval(setup_notebooks)
```
# Real-World Data
```{julia}
# Hyper:
# When true, the models are refit and re-serialized in the training chunk;
# when false, previously serialized models are loaded from disk instead.
_retrain = true
# Data:
# Number of observations requested from the loader (rebound below).
n_obs = 10000
# NOTE(review): this notebook loads the California housing data, yet every
# artifact it writes uses the "gmsc_" file prefix — confirm which dataset is intended.
counterfactual_data = load_california_housing(n_obs)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
# Transposed before being wrapped as a table — presumably so that observations
# become rows; TODO confirm against the layout returned by unpack_data.
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
# Rebinds n_obs to the second dimension of the stored feature matrix, replacing
# the requested count above; input_dim is its first dimension.
input_dim, n_obs = size(counterfactual_data.X)
# Number of distinct label values.
output_dim = length(unique(labels))
```
First, let's create a couple of classifier architectures:
```{julia}
# Model parameters:
epochs = 100
# Mini-batch size: a tenth of the observations (rounded), capped at 128.
# `min`/`round(Int, …)` yields the same value as taking `minimum` over a
# temporary two-element array, without allocating that array.
batch_size = min(round(Int, n_obs / 10), 128)
n_hidden = 64
activation = Flux.relu
# Network with one hidden layer of n_hidden units followed by a linear output layer.
# n_in and n_out are not defined in this notebook — presumably they are injected
# by the MLJFlux.@builder macro; TODO confirm.
builder = MLJFlux.@builder Flux.Chain(
    Dense(n_in, n_hidden, activation),
    Dense(n_hidden, n_out),
)
n_ens = 5 # number of models in ensemble
_loss = Flux.Losses.logitcrossentropy # loss function
_finaliser = x -> x # finaliser function
```
```{julia}
# JEM parameters:
# Distribution over inputs handed to the sampler: Normal() with its default parameters.
𝒟x = Normal()
# Categorical distribution assigning equal probability to each of the output_dim classes.
𝒟y = Categorical(ones(output_dim) ./ output_dim)
# Sampler built from the two distributions above; input_size is a 1-tuple
# holding the number of features.
sampler = ConditionalSampler(
𝒟x, 𝒟y,
input_size=(input_dim,),
batch_size=10,
)
α = [1.0,1.0,1e-2] # penalty strengths
```
```{julia}
# Simple MLP:
# Uses the shared builder, epochs, batch size, finaliser and loss defined in the
# model-parameter chunk.
mlp = NeuralNetworkClassifier(
builder=builder,
epochs=epochs,
batch_size=batch_size,
finaliser=_finaliser,
loss=_loss,
)
# Deep Ensemble:
# n_ens copies of the MLP specification.
mlp_ens = EnsembleModel(model=mlp, n=n_ens)
# Joint Energy Model:
# Same network settings as the MLP, plus the conditional sampler and the
# penalty strengths α passed through jem_training_params.
jem = JointEnergyClassifier(
sampler;
builder=builder,
epochs=epochs,
batch_size=batch_size,
finaliser=_finaliser,
loss=_loss,
jem_training_params=(
α=α,verbosity=10,
),
sampling_steps=20,
)
# JEM with adversarial training:
# NOTE(review): the line that would switch on adversarial training is commented
# out, so jem_adv is currently an unmodified deep copy of jem; within this chunk
# it is only referenced by commented-out code.
jem_adv = deepcopy(jem)
# jem_adv.adv_training = true
# Deep Ensemble of Joint Energy Models:
jem_ens = EnsembleModel(model=jem, n=n_ens)
# Deep Ensemble of Joint Energy Models with adversarial training:
# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens)
# Dictionary of models:
# Keys are the display names used when training and benchmarking.
models = Dict(
"MLP" => mlp,
"MLP Ensemble" => mlp_ens,
"JEM" => jem,
"JEM Ensemble" => jem_ens,
# "JEM Ensemble+" => jem_ens_plus,
)
```
```{julia}
# Train models:
"""
    _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model")

Wrap `model` with `conformal_model`, fit it on `X` and `y`, and return the result
as an `ECCCo.ConformalModel`.

# Arguments
- `model`: the model specification to wrap and fit.
- `X`, `y`: training inputs and targets; default to the notebook globals `X` and `labels`.

# Keywords
- `cov`: forwarded to `conformal_model` as `coverage`.
- `method`: forwarded to `conformal_model` as `method`.
- `mod_name`: name used in the log messages emitted before and after fitting.

# Returns
An `ECCCo.ConformalModel` built from the fitted machine's `model` and `fitresult`.
"""
function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model")
    wrapped = conformal_model(model; method=method, coverage=cov)
    mach = machine(wrapped, X, y)

    @info "Begin training $mod_name."
    fit!(mach)
    @info "Finished training $mod_name."

    # Package the fitted machine's model and fit result for use with ECCCo.
    return ECCCo.ConformalModel(mach.model, mach.fitresult)
end
# Either fit every entry of `models` with _train and serialize the resulting
# dictionary, or load a previously serialized dictionary from the same path.
# NOTE(review): the file is named "gmsc_models.jls" although this notebook loads
# the California housing data — confirm the intended file name.
if _retrain
model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models)
Serialization.serialize(joinpath(output_path,"gmsc_models.jls"), model_dict)
else
model_dict = Serialization.deserialize(joinpath(output_path,"gmsc_models.jls"))
end
```
```{julia}
# NOTE(review): this whole chunk is disabled. As written it would fetch its test
# set via load_mnist_test(), which does not match the data loaded in this
# notebook — swap in a matching test set before re-enabling.
# # Evaluate models:
# measure = Dict(
# :f1score => multiclass_f1score,
# :acc => accuracy,
# :precision => multiclass_precision
# )
# model_performance = DataFrame()
# for (mod_name, mod) in model_dict
# # Test performance:
# test_data = load_mnist_test()
# _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
# _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
# _perf.mod_name .= mod_name
# model_performance = vcat(model_performance, _perf)
# end
# Serialization.serialize(joinpath(output_path,"gmsc_model_performance.jls"), model_performance)
# CSV.write(joinpath(output_path, "gmsc_model_performance.csv"), model_performance)
# model_performance
```
## Benchmark
```{julia}
# Benchmark generators:
# Each generator is constructed with its default settings and keyed by a short name.
generator_dict = Dict(
:wachter => WachterGenerator(),
:revise => REVISEGenerator(),
:greedy => GreedyGenerator(),
:eccco => ECCCoGenerator(),
)
# Measures:
# Evaluation functions passed to the benchmark. The result tables built later
# select rows by the names "distance", "distance_from_energy" and
# "distance_from_targets".
measures = [
CounterfactualExplanations.distance,
ECCCo.distance_from_energy,
ECCCo.distance_from_targets,
CounterfactualExplanations.Evaluation.validity,
CounterfactualExplanations.Evaluation.redundancy,
]
# Run the benchmark over every model/generator pair using the measures above.
# suppress_training=true — presumably because model_dict already holds fitted
# models; TODO confirm.
# NOTE(review): dataname says "Californian Housing" while the output file uses
# the "gmsc_" prefix — confirm which label and file name are intended.
bmk = benchmark(
counterfactual_data;
models=model_dict,
generators=generator_dict,
measure=measures,
suppress_training=true, dataname="Californian Housing",
n_individuals=5,
factual=0, target=1,
initialization=:identity,
)
# bmk() is called to obtain the results that are written to CSV.
CSV.write(joinpath(output_path, "gmsc_benchmark.csv"), bmk())
```
```{julia}
# Summary table: mean and standard deviation of `value` for each
# (dataname, generator, model, variable) group, then restricted to the
# "distance_from_energy" rows.
@chain bmk() begin
@group_by(dataname, generator, model, variable)
@summarize(mean=mean(value),sd=std(value))
@ungroup
@filter(variable == "distance_from_energy")
end
```
```{julia}
# Keep the three distance-based measures and relabel them for plotting:
# "distance_from_energy" -> "Non-Conformity", "distance_from_targets" ->
# "Implausibility", "distance" -> "Cost".
df = @chain bmk() begin
@filter(variable in [
"distance_from_energy",
"distance_from_targets",
"distance",])
@mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
@mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable))
@mutate(variable = ifelse.(variable .== "distance", "Cost", variable))
end
# Box plots of `value` by generator, faceted with one row per measure and one
# column per model, coloured by generator.
plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) *
mapping(:generator, :value, row=:variable, col=:model, color=:generator)
# x-axis label, ticks and tick labels are hidden (generator is already mapped to colour).
plt = draw(
plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120),
facet=(; linkyaxes=:minimal)
)
display(plt)
# NOTE(review): the image is saved under the "gmsc_" prefix although this
# notebook loads the California housing data — confirm the intended file name.
save(joinpath(output_images_path, "gmsc_benchmark.png"), plt, px_per_unit=5)
```
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment