Skip to content
Snippets Groups Projects
Commit 21201316 authored by Pat Alt's avatar Pat Alt
Browse files
parents 50054938 fab794a1
No related branches found
No related tags found
No related merge requests found
```{julia}
include("notebooks/setup.jl")
eval(setup_notebooks)
```
# MNIST
```{julia}
# Data:
counterfactual_data = load_mnist()
counterfactual_data = load_mnist(1000)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
M = load_mnist_mlp()
# Target:
factual_label = 8
x = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 3
factual = predict_label(M, counterfactual_data, x)[1]
γ = 0.9
T = 50
```
```{julia}
builder = MLJFlux.@builder M.model
clf = NeuralNetworkClassifier(builder=builder, epochs=100)
epochs = 100
clf = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=relu), epochs=epochs)
conf_model = conformal_model(clf; method=:simple_inductive)
mach = machine(conf_model, X, labels)
fit!(mach)
```
```{julia}
# Search:
M = CCE.ConformalModel(mach.model, mach.fitresult)
```
```{julia}
test_data = load_mnist_test()
f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
println("F1 score (test): $(round(f1,digits=3))")
```
```{julia}
# Set up search:
factual_label = 8
x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 3
factual = predict_label(M, counterfactual_data, x)[1]
γ = 0.9
T = 100
# Generate counterfactual using generic generator:
generator = GenericGenerator()
ce_wachter = generate_counterfactual(
x, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
)
generator = deepcopy(generator) |> gen -> @objective(gen, _ + 0.001distance + 1.0set_size_penalty)
# Generate counterfactual using CCE generator:
generator = CCEGenerator(λ=[0.0,10.0], temp=0.01, opt=CounterfactualExplanations.Generators.JSMADescent(η=5.0))
ce_conformal = generate_counterfactual(
x, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
```
```{julia}
p1 = plot(
p1 = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
axis=nothing,
size=(img_height, img_height),
......@@ -55,13 +68,13 @@ p1 = plot(
)
plts = [p1]
ces = zip([ce_wachter,ce_jsma])
ces = zip([ce_wachter,ce_conformal])
counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces))
phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces))
for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3))
ce, _phat = (x[1],x[2])
_title = "p(y=$(target)|x′)=$(round(_phat[1]; digits=3))"
plt = plot(
plt = Plots.plot(
convert2image(MNIST, reshape(ce,28,28)),
axis=nothing,
size=(img_height, img_height),
......@@ -69,6 +82,7 @@ for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3))
)
plts = [plts..., plt]
end
plt = plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
savefig(plt, joinpath(www_path, "you_may_not_like_it.png"))
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(www_path, "cce_mnist.png"))
```
\ No newline at end of file
www/cce_mnist.png

13 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment