diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 1ad83cf96f8e0f5f54347ef4590d9193307a04c2..0716c8bcfe509a95540812fd7f3907628b7c6783 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -29,7 +29,6 @@ builder = MLJFlux.@builder Flux.Chain( Dense(n_in, n_hidden), BatchNorm(n_hidden, activation), Dense(n_hidden, n_out), - BatchNorm(n_out) ) # builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.2, σ=activation) # builder = MLJFlux.MLP( @@ -40,7 +39,7 @@ builder = MLJFlux.@builder Flux.Chain( # ), # σ=activation # ) -α = [1.0,1.0,5e-1] +α = [1.0,1.0,1e-2] # Simple MLP: mlp = NeuralNetworkClassifier( @@ -61,8 +60,8 @@ jem = JointEnergyClassifier( loss=Flux.Losses.logitcrossentropy, jem_training_params=( α=α,verbosity=10, - use_gen_loss=false, - use_reg_loss=false, + # use_gen_loss=false, + # use_reg_loss=false, ), sampling_steps=20, epochs=epochs,