Skip to content
Snippets Groups Projects
Commit 8bbb653b authored by Pat Alt's avatar Pat Alt
Browse files

ufff

parent 21201316
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
julia_version = "1.8.5" julia_version = "1.8.5"
manifest_format = "2.0" manifest_format = "2.0"
project_hash = "512d9080e47cd18c9e3a640716f9947dd8512bcb" project_hash = "70fd26ae44f12d0456544e252f519edafee7b553"
[[deps.AbstractFFTs]] [[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"] deps = ["ChainRulesCore", "LinearAlgebra"]
......
...@@ -21,4 +21,5 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" ...@@ -21,4 +21,5 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tidier = "f0413319-3358-4bb0-8e7c-0c83523a93bd" Tidier = "f0413319-3358-4bb0-8e7c-0c83523a93bd"
...@@ -7,7 +7,7 @@ eval(setup_notebooks) ...@@ -7,7 +7,7 @@ eval(setup_notebooks)
```{julia} ```{julia}
# Data: # Data:
counterfactual_data = load_mnist(1000) counterfactual_data = load_mnist()
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X)) X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels labels = counterfactual_data.output_encoder.labels
...@@ -16,7 +16,11 @@ input_dim, n_obs = size(counterfactual_data.X) ...@@ -16,7 +16,11 @@ input_dim, n_obs = size(counterfactual_data.X)
```{julia} ```{julia}
epochs = 100 epochs = 100
clf = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=relu), epochs=epochs) clf = NeuralNetworkClassifier(
builder=MLJFlux.MLP(hidden=(32,), σ=relu),
epochs=epochs,
batch_size=Int(round(n_obs/10))
)
conf_model = conformal_model(clf; method=:simple_inductive) conf_model = conformal_model(clf; method=:simple_inductive)
mach = machine(conf_model, X, labels) mach = machine(conf_model, X, labels)
fit!(mach) fit!(mach)
...@@ -32,11 +36,17 @@ f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data) ...@@ -32,11 +36,17 @@ f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
println("F1 score (test): $(round(f1,digits=3))") println("F1 score (test): $(round(f1,digits=3))")
``` ```
```{julia}
using CounterfactualExplanations.DataPreprocessing: undersample
dt_reduced = undersample(counterfactual_data, 100)
# dt_reduced = counterfactual_data
```
```{julia} ```{julia}
# Set up search: # Set up search:
factual_label = 8 factual_label = 9
x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 3 target = 4
factual = predict_label(M, counterfactual_data, x)[1] factual = predict_label(M, counterfactual_data, x)[1]
γ = 0.9 γ = 0.9
T = 100 T = 100
...@@ -44,15 +54,19 @@ T = 100 ...@@ -44,15 +54,19 @@ T = 100
# Generate counterfactual using generic generator: # Generate counterfactual using generic generator:
generator = GenericGenerator() generator = GenericGenerator()
ce_wachter = generate_counterfactual( ce_wachter = generate_counterfactual(
x, target, counterfactual_data, M, generator; x, target, dt_reduced, M, generator;
decision_threshold=γ, max_iter=T, decision_threshold=γ, max_iter=T,
initialization=:identity, initialization=:identity,
) )
# Generate counterfactual using CCE generator: # Generate counterfactual using CCE generator:
generator = CCEGenerator(λ=[0.0,10.0], temp=0.01, opt=CounterfactualExplanations.Generators.JSMADescent(η=5.0)) generator = CCEGenerator(
λ=[0.0,10.0],
temp=0.01,
# opt=CounterfactualExplanations.Generators.JSMADescent(η=5.0),
)
ce_conformal = generate_counterfactual( ce_conformal = generate_counterfactual(
x, target, counterfactual_data, M, generator; x, target, dt_reduced, M, generator;
decision_threshold=γ, max_iter=T, decision_threshold=γ, max_iter=T,
initialization=:identity, initialization=:identity,
converge_when=:generator_conditions, converge_when=:generator_conditions,
......
www/cce_mnist.png

13 KiB | W: | H:

www/cce_mnist.png

12.8 KiB | W: | H:

www/cce_mnist.png
www/cce_mnist.png
www/cce_mnist.png
www/cce_mnist.png
  • 2-up
  • Swipe
  • Onion skin
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment