Skip to content
Snippets Groups Projects
Commit 9da397e2 authored by pat-alt's avatar pat-alt
Browse files

work on mnist

parent 04174b2a
No related branches found
No related tags found
No related merge requests found
```{julia}
include("notebooks/setup.jl")
eval(setup_notebooks)
```
# MNIST
```{julia}
# Data:
counterfactual_data = load_mnist()
counterfactual_data = load_mnist(1000)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
M = load_mnist_mlp()
# Target:
factual_label = 8
x = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 3
factual = predict_label(M, counterfactual_data, x)[1]
γ = 0.9
T = 50
```
```{julia}
builder = MLJFlux.@builder M.model
clf = NeuralNetworkClassifier(builder=builder, epochs=100)
epochs = 100
clf = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=relu), epochs=epochs)
conf_model = conformal_model(clf; method=:simple_inductive)
mach = machine(conf_model, X, labels)
fit!(mach)
```
```{julia}
# Search:
M = CCE.ConformalModel(mach.model, mach.fitresult)
```
```{julia}
test_data = load_mnist_test()
f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
println("F1 score (test): $(round(f1,digits=3))")
```
```{julia}
# Set up search:
factual_label = 8
x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 3
factual = predict_label(M, counterfactual_data, x)[1]
γ = 0.9
T = 100
# Generate counterfactual using generic generator:
generator = GenericGenerator()
ce_wachter = generate_counterfactual(
x, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
)
generator = deepcopy(generator) |> gen -> @objective(gen, _ + 0.001distance + 1.0set_size_penalty)
# Generate counterfactual using CCE generator:
generator = CCEGenerator(λ=[0.0,10.0], temp=0.01, opt=CounterfactualExplanations.Generators.JSMADescent(η=5.0))
ce_conformal = generate_counterfactual(
x, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
```
```{julia}
p1 = plot(
p1 = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
axis=nothing,
size=(img_height, img_height),
......@@ -55,13 +68,13 @@ p1 = plot(
)
plts = [p1]
ces = zip([ce_wachter,ce_jsma])
ces = zip([ce_wachter,ce_conformal])
counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces))
phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces))
for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3))
ce, _phat = (x[1],x[2])
_title = "p(y=$(target)|x′)=$(round(_phat[1]; digits=3))"
plt = plot(
plt = Plots.plot(
convert2image(MNIST, reshape(ce,28,28)),
axis=nothing,
size=(img_height, img_height),
......@@ -69,6 +82,7 @@ for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3))
)
plts = [plts..., plt]
end
plt = plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
savefig(plt, joinpath(www_path, "you_may_not_like_it.png"))
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(www_path, "cce_mnist.png"))
```
\ No newline at end of file
www/cce_mnist.png

13 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment