diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 433e3ba30df6509ee002fed887ba04f5fa7bef51..2582057672d30489bc1b66ae9032965004b129f6 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -25,8 +25,21 @@ epochs = 100
 batch_size = Int(round(n_obs/10))
 n_hidden = 32
 activation = Flux.swish
-builder = MLJFlux.MLP(hidden=(n_hidden,), σ=activation)
-α = [0.33,1.0,1e-1]
+builder = MLJFlux.@builder Flux.Chain(
+    Dense(n_in, n_hidden),
+    BatchNorm(n_hidden, activation),
+    Dense(n_hidden, n_out),
+    BatchNorm(n_out)
+)
+# builder = MLJFlux.MLP(
+#     hidden=(
+#         n_hidden,
+#         n_hidden,
+#         n_hidden,
+#     ),
+#     σ=activation
+# )
+α = [1.0,1.0,1e-1]
 
 # Simple MLP:
 mlp = NeuralNetworkClassifier(
@@ -36,17 +49,18 @@ mlp = NeuralNetworkClassifier(
 )
 
 # Joint Energy Model:
-𝒟x = Uniform(-1,1)
+𝒟x = Uniform(0,1)
 𝒟y = Categorical(ones(output_dim) ./ output_dim)
 sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=(input_dim,), batch_size=batch_size)
 jem = JointEnergyClassifier(
     sampler;
     builder=builder,
     batch_size=batch_size,
-    finaliser=Flux.softmax,
-    loss=Flux.Losses.crossentropy,
+    finaliser=x -> x,
+    loss=Flux.Losses.logitcrossentropy,
     jem_training_params=(α=α,verbosity=10,),
     sampling_steps=20,
+    epochs=epochs,
 )
 
 # Deep Ensemble:
@@ -54,8 +68,8 @@ mlp_ens = EnsembleModel(model=mlp, n=5)
 ```
 
 ```{julia}
-cov = .9
-conf_model = conformal_model(jem; method=:adaptive_inductive, coverage=cov)
+cov = .90
+conf_model = conformal_model(jem; method=:simple_inductive, coverage=cov)
 mach = machine(conf_model, X, labels)
 fit!(mach)
 M = CCE.ConformalModel(mach.model, mach.fitresult)
@@ -63,7 +77,7 @@ M = CCE.ConformalModel(mach.model, mach.fitresult)
 
 ```{julia}
 jem = mach.model.model.jem
-n_iter = 5000
+n_iter = 100
 _w = 1500
 plts = []
 neach = 10
@@ -78,7 +92,8 @@ for i in 1:10
     plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
     plts = [plts..., plt]
 end
-Plots.plot(plts..., size=(_w,_w), layout=(10,1))
+plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1))
+display(plt)
 ```
 
 ```{julia}
@@ -88,12 +103,14 @@ println("F1 score (test): $(round(f1,digits=3))")
 ```
 
 ```{julia}
+Random.seed!(1234)
+
 # Set up search:
-factual_label = 9
+factual_label = 2
 x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
-target = 4
+target = 0
 factual = predict_label(M, counterfactual_data, x)[1]
-γ = 0.9
+γ = 0.5
 T = 100
 
 # Generate counterfactual using generic generator:
 ce_wachter = generate_counterfactual(
@@ -106,9 +123,9 @@ ce_wachter = generate_counterfactual(
 )
 
 # Generate counterfactual using CCE generator:
 generator = CCEGenerator(
-    λ=[0.0,10.0],
-    temp=0.01,
-    # opt=CounterfactualExplanations.Generators.JSMADescent(η=0.5),
+    λ=[0.0,1.0],
+    temp=0.5,
+    # opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
 )
 ce_conformal = generate_counterfactual(
     x, target, counterfactual_data, M, generator;
diff --git a/src/model.jl b/src/model.jl
index 2d3ec05df4434394173ee7f9ec781331e1a2200d..c724867b51e2d03c696b991d1d708766525d1cfb 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -86,7 +86,11 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
         p̂ = [p̂]
     end
     p̂ = reduce(hcat, p̂)
-    ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂))))
+    if all(0.0 .<= vec(p̂) .<= 1.0)
+        ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂))))
+    else
+        ŷ = p̂
+    end
     if M.likelihood == :classification_binary
         ŷ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ))))
     end
diff --git a/www/cce_mnist.png b/www/cce_mnist.png
index 8af78f6df7b2364c9d6883f4b5dea1046c1ac173..3706287ac1d5a3ee59587e8336bb1e687f5a6286 100644
Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ
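
Note on the `src/model.jl` hunk: because the notebook now builds the JEM with `finaliser=x -> x` and `Flux.Losses.logitcrossentropy`, the underlying model returns raw logits, while other wrapped models may still return softmax probabilities, so the patched `Models.logits` branches on whether every entry of `p̂` lies in [0,1]. A minimal standalone sketch of that branch; the helper name `to_logits` and the round-trip check are illustrative, not part of the package:

```julia
using Flux

# Mirrors the dispatch added to Models.logits: columns of p̂ lying entirely
# in [0,1] are treated as probabilities and mapped back to logits via an
# inverse softmax; anything else is assumed to be logits already.
function to_logits(p̂::AbstractMatrix)
    if all(0.0 .<= vec(p̂) .<= 1.0)
        return reduce(hcat, map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂)))
    else
        return p̂
    end
end

# Softmax is shift-invariant, so adding the constant log(sum(exp.(p))) to
# log.(p) yields valid logits: softmax(to_logits(p)) recovers p exactly.
p = Flux.softmax(randn(3, 2))
@assert Flux.softmax(to_logits(p)) ≈ p
```

Any constant shift of `log.(p)` would produce logits consistent with `p`; `log(sum(exp.(p)))` is simply the choice the patch makes.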