Skip to content
Snippets Groups Projects
Commit 2a95342c authored by Pat Alt's avatar Pat Alt
Browse files

support for weighted outcomes

parent d3ebe58d
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -78,12 +78,23 @@ const ECCCo_Δ_NAMES = [
Returns the best outcome from grid search results. The best outcome is defined as the one with the lowest average rank across all datasets and variables for the specified generator and measure.
"""
function best_outcome(outcomes::Dict; generator=ALL_ECCCO_NAMES, measure=["distance_from_energy_l2", "distance_from_targets_l2"], model::Union{Nothing,AbstractArray}=nothing)
function best_outcome(
outcomes::Dict;
generator=ALL_ECCCO_NAMES,
measure=["distance_from_energy_l2", "distance_from_targets_l2"],
model::Union{Nothing,AbstractArray}=nothing,
weights::Union{Nothing,AbstractArray}=nothing
)
weights = isnothing(weights) ? ones(length(measure)) : weights
df_weights = DataFrame(variable=measure, weight=weights)
ranks = []
for (params, outcome) in outcomes
_ranks = generator_rank(outcome; generator=generator, measure=measure, model=model) |>
x -> x.avg_rank |>
x -> (sum(x) / length(x))[1]
x -> leftjoin(x, df_weights, on=:variable) |>
x -> x.avg_rank .* x.weight |>
x -> (sum(x) / length(x))[1]
push!(ranks, _ranks)
end
best_index = argmin(ranks)
......@@ -94,9 +105,9 @@ function best_outcome(outcomes::Dict; generator=ALL_ECCCO_NAMES, measure=["dista
return best_outcome
end
best_eccco(outcomes) = best_outcome(outcomes; generator=ECCCO_NAMES)
best_eccco(outcomes; kwrgs...) = best_outcome(outcomes; generator=ECCCO_NAMES, kwrgs...)
best_eccco_Δ(outcomes) = best_outcome(outcomes; generator=ECCCo_Δ_NAMES)
best_eccco_Δ(outcomes; kwrgs...) = best_outcome(outcomes; generator=ECCCo_Δ_NAMES, kwrgs...)
"""
best_absolute_outcome(outcomes; generator=ECCCO_NAMES, measure="distance_from_energy")
......@@ -125,9 +136,9 @@ function best_absolute_outcome(outcomes::Dict; generator=ECCCO_NAMES, measure::S
)
end
best_absolute_outcome_eccco(outcomes) = best_absolute_outcome(outcomes; generator=ECCCO_NAMES)
best_absolute_outcome_eccco(outcomes; kwrgs...) = best_absolute_outcome(outcomes; generator=ECCCO_NAMES, kwrgs...)
best_absolute_outcome_eccco_Δ(outcomes) = best_absolute_outcome(outcomes; generator=ECCCo_Δ_NAMES)
best_absolute_outcome_eccco_Δ(outcomes; kwrgs...) = best_absolute_outcome(outcomes; generator=ECCCo_Δ_NAMES, kwrgs...)
"""
append_best_params!(params::NamedTuple, dataname::String)
......
......@@ -3,6 +3,8 @@ dataname = "MNIST"
n_obs = 10000
counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = ECCCo.pre_process.(counterfactual_data.X)
# Adjust domain constraints to account for noise added during pre-processing:
counterfactual_data.domain = fill((minimum(counterfactual_data.X), maximum(counterfactual_data.X)), size(counterfactual_data.X, 1))
# VAE (trained on full dataset):
using CounterfactualExplanations.Models: load_mnist_vae
......@@ -44,7 +46,7 @@ params = (
niter_eccco=10,
Λ=[0.1, 0.25, 0.25],
Λ_Δ=[0.1, 0.1, 1.0],
opt=Flux.Optimise.Descent(0.1),
opt=Flux.Optimise.Descent(0.25),
reg_strength = 0.01,
ce_measures=ce_measures,
)
......
......@@ -25,6 +25,7 @@ using MLJEnsembles
using MLJFlux
using Random
using Serialization
using Statistics
import MPI
......@@ -158,7 +159,6 @@ DEFAULT_GENERATOR_TUNING = (
Flux.Optimise.Descent(0.1),
Flux.Optimise.Descent(0.05),
Flux.Optimise.Descent(0.01),
Flux.Optimise.Descent(0.001),
],
)
......
......@@ -58,5 +58,5 @@ Computes 1-SSIM between two images.
function ssim_dist(x, y)
x = convert2mnist(x)
y = convert2mnist(y)
return 1 - assess_ssim(x, y)
return (1 - assess_ssim(x, y))/2
end
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment