Skip to content
Snippets Groups Projects
Commit 1dd40f47 authored by Pat Alt's avatar Pat Alt
Browse files

:sob

parent e0154ef6
No related branches found
No related tags found
No related merge requests found
......@@ -79,7 +79,7 @@ mlp_ens = EnsembleModel(model=mlp, n=5)
```{julia}
cov = .90
conf_model = conformal_model(mlp; method=:adaptive_inductive, coverage=cov)
conf_model = conformal_model(jem; method=:adaptive_inductive, coverage=cov)
mach = machine(conf_model, X, labels)
fit!(mach)
M = CCE.ConformalModel(mach.model, mach.fitresult)
......@@ -115,32 +115,40 @@ println("F1 score (test): $(round(f1,digits=3))")
```
```{julia}
Random.seed!(1234)
# Random.seed!(1234)
# Set up search:
factual_label = 4
factual_label = 8
x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 9
target = 3
factual = predict_label(M, counterfactual_data, x)[1]
γ = 0.5
T = 100
# Generate counterfactual using generic generator:
generator = GenericGenerator()
generator = GenericGenerator(opt=Flux.Optimise.Adam(),)
ce_wachter = generate_counterfactual(
x, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
)
generator = GreedyGenerator(η=1.0)
ce_jsma = generate_counterfactual(
x, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
)
# CCE:
λ=[0.0,1.0]
temp=0.5
temp=0.01
# Generate counterfactual using CCE generator:
generator = CCEGenerator(
λ=λ,
temp=temp,
opt=Flux.Optimise.Adam(),
)
ce_conformal = generate_counterfactual(
x, target, counterfactual_data, M, generator;
......@@ -171,15 +179,15 @@ p1 = Plots.plot(
)
plts = [p1]
ces = zip([ce_wachter,ce_conformal,ce_conformal_jsma])
_names = ["Wachter", "CCE", "CCE-JSMA"]
counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces))
phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces))
for x in zip(eachslice(counterfactuals; dims=3), _names, eachslice(phat; dims=3))
ce, _name, _phat = (x[1],x[2],x[3])
ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma]
_names = ["Wachter", "CCE", "JSMA", "CCE-JSMA"]
for x in zip(ces, _names)
ce, _name = (x[1],x[2])
x = CounterfactualExplanations.counterfactual(ce)
_phat = target_probs(ce)
_title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
plt = Plots.plot(
convert2image(MNIST, reshape(ce,28,28)),
convert2image(MNIST, reshape(x,28,28)),
axis=nothing,
size=(img_height, img_height),
title=_title
......
www/cce_mnist.png

15.4 KiB | W: | H:

www/cce_mnist.png

21.3 KiB | W: | H:

www/cce_mnist.png
www/cce_mnist.png
www/cce_mnist.png
www/cce_mnist.png
  • 2-up
  • Swipe
  • Onion skin
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment