diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 2e26b2203b5d1a2b893e38b5b82add2c8ceed0be..9b29c9960d7e9971fe152c6b6cf697a38f2911ed 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -21,7 +21,7 @@ clf = NeuralNetworkClassifier( epochs=epochs, batch_size=Int(round(n_obs/10)) ) -conf_model = conformal_model(clf; method=:simple_inductive, coverage=.99) +conf_model = conformal_model(clf; method=:adaptive_inductive, coverage=.99) mach = machine(conf_model, X, labels) fit!(mach) ``` @@ -44,9 +44,9 @@ dt_reduced = counterfactual_data ```{julia} # Set up search: -factual_label = 9 +factual_label = 8 x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) -target = 4 +target = 3 factual = predict_label(M, counterfactual_data, x)[1] γ = 0.9 T = 100 @@ -61,7 +61,7 @@ ce_wachter = generate_counterfactual( # Generate counterfactual using CCE generator: generator = CCEGenerator( - λ=[0.0,10.0], + λ=[0.0,100.0], temp=0.01, # opt=CounterfactualExplanations.Generators.JSMADescent(η=5.0), ) @@ -71,9 +71,8 @@ ce_conformal = generate_counterfactual( initialization=:identity, converge_when=:generator_conditions, ) -``` -```{julia} +# Plot: p1 = Plots.plot( convert2image(MNIST, reshape(x,28,28)), axis=nothing, diff --git a/www/cce_mnist.png b/www/cce_mnist.png index 100266f6f837ace4fedc8fc475fa6c04909501bf..1fdc26992789f5fc8ee7cd93a4039db059e22b2a 100644 Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ