Commit 302d344f authored by pat-alt

confused about domain of classification loss

parent 136467dc
@@ -5,6 +5,7 @@ using CounterfactualExplanations
 using CounterfactualExplanations.Data
 using CounterfactualExplanations.Objectives
 using Flux
+using LinearAlgebra
 using MLJBase
 using MLJFlux
 using Plots
@@ -102,9 +103,9 @@ is a smooth size penalty for conformal classifiers introduced by @stutz2022learn
 #| echo: false
 #| label: fig-losses
 #| fig-cap: "Illustration of the smooth size loss and the configurable classification loss."
-p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true, zoom=0)
-p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=target, zoom=0)
+temp = 0.5
+p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true, zoom=0, temp=temp)
+p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=target, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2,2))
 plot(p1, p2, size=(800,320))
 ```
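
For orientation (an illustration, not part of this commit): the smooth size loss referenced in the hunk above relaxes the hard prediction-set membership indicator into a sigmoid, so the `temp` value now threaded through the `contourf` calls controls how sharply scores near the calibrated threshold count towards the set size. A minimal sketch in the spirit of @stutz2022learn, assuming softmax probabilities `probs`, a calibrated quantile `q̂` and a target set size `κ` (all hypothetical names, not identifiers from this repository):

```julia
using Flux: σ

# Sigmoid-relaxed ("smooth") set size penalty: sum the soft membership of each
# class in the prediction set and penalise anything beyond a target size κ.
function smooth_set_size(probs::AbstractVector, q̂::Real; temp::Real=0.5, κ::Int=1)
    soft_membership = σ.((probs .- (1 - q̂)) ./ temp)  # ≈1 if a class clears the calibrated threshold
    return max(0, sum(soft_membership) - κ)
end
```

A smaller `temp` pushes the soft memberships towards hard 0/1 indicators, which is presumably why the figure is regenerated here with an explicit `temp=temp` rather than a library default.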
@@ -157,6 +158,7 @@ plot(plts..., size=(_n * img_size,1.05*img_size), layout=(1,_n))
 #| fig-cap: "Comparison of counterfactuals produced using different generators."
 opt = Descent(0.01)
+temp = 0.75
 ordered_names = [
     "Generic (γ=0.5)",
     "Generic (γ=0.9)",
@@ -169,8 +171,8 @@ loss_fun = Objectives.logitbinarycrossentropy
 generators = Dict(
     ordered_names[1] => GenericGenerator(opt = opt, decision_threshold=0.5),
     ordered_names[2] => GenericGenerator(opt = opt, decision_threshold=0.9),
-    ordered_names[3] => CCE.ConformalGenerator(opt=opt, λ=[0.1,1]),
-    ordered_names[4] => CCE.ConformalGenerator(opt=opt, λ=[0.1,10]),
+    ordered_names[3] => CCE.ConformalGenerator(opt=opt, λ=[0.1,1], temp=temp),
+    ordered_names[4] => CCE.ConformalGenerator(opt=opt, λ=[0.1,10], temp=temp),
 )
 counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity) for (name, gen) in generators])
@@ -63,18 +63,21 @@ A configurable classification loss function for Conformal Predictors.
 function conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...)
     conf_model = counterfactual_explanation.M.model
     fitresult = counterfactual_explanation.M.fitresult
+    generator = counterfactual_explanation.generator
+    temp = hasfield(typeof(generator), :temp) ? generator.temp : nothing
+    K = length(counterfactual_explanation.data.y_levels)
     X = CounterfactualExplanations.decode_state(counterfactual_explanation)
     y = counterfactual_explanation.target_encoded[:,:,1]
     if counterfactual_explanation.M.likelihood == :classification_binary
         y = binary_to_onehot(y)
     end
     y = permutedims(y)
-    generator = counterfactual_explanation.generator
     loss = SliceMap.slicemap(X, dims=(1, 2)) do x
-        x = Matrix(x)
+        x = Matrix(x)
         ConformalPrediction.classification_loss(
             conf_model, fitresult, x, y;
-            temp=generator.temp
+            temp=temp,
+            loss_matrix=Float32.(ones(K,K))
         )
     end
     loss = mean(loss)
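
For context (again an illustration, not code from this commit): given soft set memberships `C[k] ∈ [0, 1]` and a `K×K` loss matrix `L`, the configurable classification loss of @stutz2022learn roughly penalises leaving the target class out of the prediction set (diagonal entries) and admitting other classes into it (off-diagonal entries), so the `loss_matrix=Float32.(ones(K,K))` passed above weights both kinds of error equally. A minimal sketch under those assumptions, with hypothetical names:

```julia
# Configurable classification loss over soft set memberships C for target class y.
# L is a K×K loss matrix; the identity matrix only penalises missing the target class.
function classification_loss_sketch(C::AbstractVector, y::Int, L::AbstractMatrix)
    K = length(C)
    return sum(L[y, k] * (k == y ? 1 - C[k] : C[k]) for k in 1:K)
end
```

Whether the memberships `C` should be computed from probabilities or from logits is presumably the "domain" question raised in the commit message; the sketch deliberately leaves the scores abstract.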