Commit b4d0b633 authored by pat-alt

confused about domain of classification loss

parent 3a38fe8f
@@ -5,6 +5,7 @@ using CounterfactualExplanations
 using CounterfactualExplanations.Data
 using CounterfactualExplanations.Objectives
 using Flux
+using LinearAlgebra
 using MLJBase
 using MLJFlux
 using Plots
@@ -102,9 +103,9 @@ is a smooth size penalty for conformal classifiers introduced by @stutz2022learn
 #| echo: false
 #| label: fig-losses
 #| fig-cap: "Illustration of the smooth size loss and the configurable classification loss."
+temp = 0.5
-p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true, zoom=0)
+p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true, zoom=0, temp=temp)
-p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=target, zoom=0)
+p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=target, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2,2))
 plot(p1, p2, size=(800,320))
 ```
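For context on the new `temp` keyword: a minimal sketch of the kind of temperature-scaled soft set size it is assumed to control, following the smooth size penalty of @stutz2022learn. The `soft_set_size` helper below is hypothetical and not part of ConformalPrediction.jl.

```julia
# Hypothetical helper (illustration only): softly count how many labels fall
# inside the conformal prediction set; `temp` controls how sharply the sigmoid
# approximates the hard indicator s_k <= qhat.
soft_set_size(scores::AbstractVector, qhat::Real; temp::Real=0.5) =
    sum(1 ./ (1 .+ exp.((scores .- qhat) ./ temp)))

scores = [0.2, 0.6, 0.9]                 # toy nonconformity scores for three labels
qhat = 0.7                               # toy calibration quantile
soft_set_size(scores, qhat; temp=0.5)    # smooth count, useful for gradients
soft_set_size(scores, qhat; temp=0.05)   # sharper, close to the hard count of 2
```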
@@ -157,6 +158,7 @@ plot(plts..., size=(_n * img_size,1.05*img_size), layout=(1,_n))
 #| fig-cap: "Comparison of counterfactuals produced using different generators."
 opt = Descent(0.01)
+temp = 0.75
 ordered_names = [
     "Generic (γ=0.5)",
     "Generic (γ=0.9)",
@@ -169,8 +171,8 @@ loss_fun = Objectives.logitbinarycrossentropy
 generators = Dict(
     ordered_names[1] => GenericGenerator(opt = opt, decision_threshold=0.5),
     ordered_names[2] => GenericGenerator(opt = opt, decision_threshold=0.9),
-    ordered_names[3] => CCE.ConformalGenerator(opt=opt, λ=[0.1,1]),
+    ordered_names[3] => CCE.ConformalGenerator(opt=opt, λ=[0.1,1], temp=temp),
-    ordered_names[4] => CCE.ConformalGenerator(opt=opt, λ=[0.1,10]),
+    ordered_names[4] => CCE.ConformalGenerator(opt=opt, λ=[0.1,10], temp=temp),
 )
 counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity) for (name, gen) in generators])
...
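As a rough illustration of how the λ vector above is read (an assumption for illustration, not taken from this diff): the two entries are assumed to weight a proximity penalty and the conformal set-size penalty, so λ=[0.1,10] pushes much harder towards counterfactuals with small prediction sets than λ=[0.1,1]. The `toy_objective` function below is hypothetical.

```julia
# Toy sketch under the stated assumption; not the package's actual objective.
using LinearAlgebra

function toy_objective(x, x_cf, set_size; λ=[0.1, 1.0])
    dist = norm(x_cf .- x, 1)              # proximity penalty
    return λ[1] * dist + λ[2] * set_size   # weighted sum of the two penalties
end

x, x_cf = [0.0, 0.0], [0.5, -0.3]
toy_objective(x, x_cf, 1.8; λ=[0.1, 1.0])    # weaker pressure on the set size
toy_objective(x, x_cf, 1.8; λ=[0.1, 10.0])   # stronger pressure on the set size
```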
@@ -63,18 +63,21 @@ A configurable classification loss function for Conformal Predictors.
 function conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...)
     conf_model = counterfactual_explanation.M.model
     fitresult = counterfactual_explanation.M.fitresult
+    generator = counterfactual_explanation.generator
+    temp = hasfield(typeof(generator), :temp) ? generator.temp : nothing
+    K = length(counterfactual_explanation.data.y_levels)
     X = CounterfactualExplanations.decode_state(counterfactual_explanation)
     y = counterfactual_explanation.target_encoded[:,:,1]
     if counterfactual_explanation.M.likelihood == :classification_binary
         y = binary_to_onehot(y)
     end
     y = permutedims(y)
-    generator = counterfactual_explanation.generator
     loss = SliceMap.slicemap(X, dims=(1, 2)) do x
         x = Matrix(x)
         ConformalPrediction.classification_loss(
             conf_model, fitresult, x, y;
-            temp=generator.temp
+            temp=temp,
+            loss_matrix=Float32.(ones(K,K))
         )
     end
     loss = mean(loss)
...
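The moved `generator` lookup and the new `temp`/`loss_matrix` keywords boil down to a small pattern that can be sketched in isolation. The structs below are hypothetical stand-ins, not CCE or CounterfactualExplanations types: read `temp` only from generators that define it, and default to a uniform K×K loss matrix.

```julia
# Hypothetical stand-ins for illustration only.
struct ToyConformalGenerator
    temp::Float64
end
struct ToyGenericGenerator end   # defines no `temp` field

# Same guard as in the diff: fall back to `nothing` when the field is absent.
get_temp(gen) = hasfield(typeof(gen), :temp) ? gen.temp : nothing

get_temp(ToyConformalGenerator(0.75))   # 0.75
get_temp(ToyGenericGenerator())         # nothing

# Uniform misclassification cost over K classes, as in Float32.(ones(K, K)).
K = 2
loss_matrix = Float32.(ones(K, K))
```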