diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 2582057672d30489bc1b66ae9032965004b129f6..1ad83cf96f8e0f5f54347ef4590d9193307a04c2 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -7,7 +7,7 @@ eval(setup_notebooks) ```{julia} # Data: -n_obs = 1000 +n_obs = 10000 counterfactual_data = load_mnist(n_obs) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) X = table(permutedims(X)) @@ -22,15 +22,16 @@ First, let's create a couple of image classifier architectures: ```{julia} # Model parameters: epochs = 100 -batch_size = Int(round(n_obs/10)) +batch_size = minimum([Int(round(n_obs/10)), 100]) n_hidden = 32 -activation = Flux.swish +activation = Flux.relu builder = MLJFlux.@builder Flux.Chain( Dense(n_in, n_hidden), BatchNorm(n_hidden, activation), Dense(n_hidden, n_out), BatchNorm(n_out) ) +# builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.2, σ=activation) # builder = MLJFlux.MLP( # hidden=( # n_hidden, @@ -39,7 +40,7 @@ builder = MLJFlux.@builder Flux.Chain( # ), # σ=activation # ) -α = [1.0,1.0,1e-1] +α = [1.0,1.0,5e-1] # Simple MLP: mlp = NeuralNetworkClassifier( @@ -58,7 +59,11 @@ jem = JointEnergyClassifier( batch_size=batch_size, finaliser=x -> x, loss=Flux.Losses.logitcrossentropy, - jem_training_params=(α=α,verbosity=10,), + jem_training_params=( + α=α,verbosity=10, + use_gen_loss=false, + use_reg_loss=false, + ), sampling_steps=20, epochs=epochs, ) @@ -106,12 +111,12 @@ println("F1 score (test): $(round(f1,digits=3))") Random.seed!(1234) # Set up search: -factual_label = 2 +factual_label = 9 x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) -target = 0 +target = 4 factual = predict_label(M, counterfactual_data, x)[1] γ = 0.5 -T = 100 +T = 1 # Generate counterfactual using generic generator: generator = GenericGenerator() diff --git a/www/cce_mnist.png b/www/cce_mnist.png index 3706287ac1d5a3ee59587e8336bb1e687f5a6286..ea956a80689729139b58966c11514539991a792a 100644 Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ