diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd index 015fbe0d03bf0a73a962ea731533f0efc4d19990..1dfdff9849fe6132a4c23876bee4bd30000d5c52 100644 --- a/notebooks/conformal.qmd +++ b/notebooks/conformal.qmd @@ -5,6 +5,7 @@ using CounterfactualExplanations using CounterfactualExplanations.Data using CounterfactualExplanations.Objectives using Flux +using LinearAlgebra using MLJBase using MLJFlux using Plots @@ -102,9 +103,9 @@ is a smooth size penalty for conformal classifiers introduced by @stutz2022learn #| echo: false #| label: fig-losses #| fig-cap: "Illustration of the smooth size loss and the configurable classification loss." - -p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true, zoom=0) -p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=target, zoom=0) +temp = 0.5 +p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true, zoom=0, temp=temp) +p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=target, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2,2)) plot(p1, p2, size=(800,320)) ``` @@ -157,6 +158,7 @@ plot(plts..., size=(_n * img_size,1.05*img_size), layout=(1,_n)) #| fig-cap: "Comparison of counterfactuals produced using different generators." opt = Descent(0.01) +temp = 0.75 ordered_names = [ "Generic (γ=0.5)", "Generic (γ=0.9)", @@ -169,8 +171,8 @@ loss_fun = Objectives.logitbinarycrossentropy generators = Dict( ordered_names[1] => GenericGenerator(opt = opt, decision_threshold=0.5), ordered_names[2] => GenericGenerator(opt = opt, decision_threshold=0.9), - ordered_names[3] => CCE.ConformalGenerator(opt=opt, λ=[0.1,1]), - ordered_names[4] => CCE.ConformalGenerator(opt=opt, λ=[0.1,10]), + ordered_names[3] => CCE.ConformalGenerator(opt=opt, λ=[0.1,1], temp=temp), + ordered_names[4] => CCE.ConformalGenerator(opt=opt, λ=[0.1,10], temp=temp), ) counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity) for (name, gen) in generators]) diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl index 280e483c84c5025af68d0656ca524a7ee1c12585..b169399b47a65d4fe0fff8c35087bf68c990017a 100644 --- a/src/ConformalGenerator.jl +++ b/src/ConformalGenerator.jl @@ -63,18 +63,21 @@ A configurable classification loss function for Conformal Predictors. function conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...) conf_model = counterfactual_explanation.M.model fitresult = counterfactual_explanation.M.fitresult + generator = counterfactual_explanation.generator + temp = hasfield(typeof(generator), :temp) ? generator.temp : nothing + K = length(counterfactual_explanation.data.y_levels) X = CounterfactualExplanations.decode_state(counterfactual_explanation) y = counterfactual_explanation.target_encoded[:,:,1] if counterfactual_explanation.M.likelihood == :classification_binary y = binary_to_onehot(y) end y = permutedims(y) - generator = counterfactual_explanation.generator loss = SliceMap.slicemap(X, dims=(1, 2)) do x - x = Matrix(x) + x = Matrix(x) ConformalPrediction.classification_loss( conf_model, fitresult, x, y; - temp=generator.temp + temp=temp, + loss_matrix=Float32.(ones(K,K)) ) end loss = mean(loss)