Skip to content
Snippets Groups Projects

73 aries comments

Merged Imported Patrick Altmeyer requested to merge 73-aries-comments into main
1 file
+ 13
18
Compare changes
  • Side-by-side
  • Inline
@@ -62,21 +62,15 @@ function plot_random_eccco(outcome::ExperimentOutcome; generator="ECCCo-Δ", img
# Get output:
bmk = outcome.bmk()
grouped_bmk = groupby(bmk[bmk.variable.=="distance" .&& bmk.generator.==generator, :], [:dataname, :target, :factual])
random_choice = rand(1:length(grouped_bmk))
models = unique(bmk.model)
n_models = length(models)
df = grouped_bmk[random_choice]
while nrow(df) > n_models
random_choice = rand(1:length(grouped_bmk))
df = grouped_bmk[random_choice]
end
sort!(df, :model)
models = df.model
ce = rand(bmk.ce)
gen = outcome.generator_dict[generator]
models = outcome.model_dict
x = CounterfactualExplanations.counterfactual(ce)
target = ce.target
data = ce.data
# Factual:
img = CounterfactualExplanations.factual(df.ce[1]) |> ECCCo.convert2mnist
img = CounterfactualExplanations.factual(ce) |> ECCCo.convert2mnist
p1 = Plots.plot(
img,
axis=([], false),
@@ -85,17 +79,18 @@ function plot_random_eccco(outcome::ExperimentOutcome; generator="ECCCo-Δ", img
)
plts = [p1]
# Counterfactuals:
for (i, model) in enumerate(models)
img = CounterfactualExplanations.counterfactual(df.ce[i]) |> ECCCo.convert2mnist
for (model_name, M) in models
ce = generate_counterfactual(x, target, data, M, gen; initialization=:identity, converge_when=:generator_conditions)
img = CounterfactualExplanations.counterfactual(ce) |> ECCCo.convert2mnist
p = Plots.plot(
img,
axis=([], false),
size=(img_height, img_height),
title="$model",
title="$model_name",
)
push!(plts, p)
end
n_models = length(models)
plt = Plots.plot(
plts...,
@@ -104,5 +99,5 @@ function plot_random_eccco(outcome::ExperimentOutcome; generator="ECCCo-Δ", img
)
display(plt)
return plt, df.target[1], seed
return plt, target, seed
end
\ No newline at end of file
Loading