diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml index edccb799635413a5efd692ffbc26e1aa806f8fcb..c0afb2087ea078b7fbf8631438e5f0ff39b87485 100644 --- a/notebooks/Manifest.toml +++ b/notebooks/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "512d9080e47cd18c9e3a640716f9947dd8512bcb" +project_hash = "70fd26ae44f12d0456544e252f519edafee7b553" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] diff --git a/notebooks/Project.toml b/notebooks/Project.toml index 538c28795473ee6ed358add88b6a4a47de8d10cb..614baa9d9bb2db96651f20a59af9ab5115217362 100644 --- a/notebooks/Project.toml +++ b/notebooks/Project.toml @@ -21,4 +21,5 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Tidier = "f0413319-3358-4bb0-8e7c-0c83523a93bd" diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index c04ca57ef70b11f8d4239ff7fc307fdb25d28318..7a1a41dd51db4beff52d2b9a18c46844b4ed350e 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -7,7 +7,7 @@ eval(setup_notebooks) ```{julia} # Data: -counterfactual_data = load_mnist(1000) +counterfactual_data = load_mnist() X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) X = table(permutedims(X)) labels = counterfactual_data.output_encoder.labels @@ -16,7 +16,11 @@ input_dim, n_obs = size(counterfactual_data.X) ```{julia} epochs = 100 -clf = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=relu), epochs=epochs) +clf = NeuralNetworkClassifier( + builder=MLJFlux.MLP(hidden=(32,), σ=relu), + epochs=epochs, + batch_size=Int(round(n_obs/10)) +) conf_model = conformal_model(clf; method=:simple_inductive) mach = machine(conf_model, X, labels) fit!(mach) @@ -32,11 +36,17 @@ f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data) println("F1 score (test): $(round(f1,digits=3))") ``` +```{julia} +using CounterfactualExplanations.DataPreprocessing: undersample +dt_reduced = undersample(counterfactual_data, 100) +# dt_reduced = counterfactual_data +``` + ```{julia} # Set up search: -factual_label = 8 +factual_label = 9 x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) -target = 3 +target = 4 factual = predict_label(M, counterfactual_data, x)[1] γ = 0.9 T = 100 @@ -44,15 +54,19 @@ T = 100 # Generate counterfactual using generic generator: generator = GenericGenerator() ce_wachter = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x, target, dt_reduced, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, ) # Generate counterfactual using CCE generator: -generator = CCEGenerator(λ=[0.0,10.0], temp=0.01, opt=CounterfactualExplanations.Generators.JSMADescent(η=5.0)) +generator = CCEGenerator( + λ=[0.0,10.0], + temp=0.01, + # opt=CounterfactualExplanations.Generators.JSMADescent(η=5.0), +) ce_conformal = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x, target, dt_reduced, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, converge_when=:generator_conditions, diff --git a/www/cce_mnist.png b/www/cce_mnist.png index 85eed69096b075c1519480191c1693cf0ceaf950..6de563ed4b67f3c267336a85556098625edc0b55 100644 Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ