Skip to content
Snippets Groups Projects
losses.jl 1004 B
using CounterfactualExplanations
using Statistics: mean

@doc raw"""
    conformal_training_loss(ce::AbstractCounterfactualExplanation; temp::Real=0.1, agg=mean, kwargs...)

A configurable classification loss function for Conformal Predictors.
"""
function conformal_training_loss(ce::AbstractCounterfactualExplanation; temp::Real=0.1, agg=mean, kwargs...)
    conf_model = ce.M.model
    fitresult = ce.M.fitresult
    X = CounterfactualExplanations.decode_state(ce)
    y = ce.target_encoded[:, :, 1]
    if ce.M.likelihood == :classification_binary
        y = binary_to_onehot(y)
    end
    y = permutedims(y)

    n_classes = length(ce.data.y_levels)
    loss_mat = ones(n_classes, n_classes)
    loss = map(eachslice(X, dims=ndims(X))) do x
        x = ndims(x) == 1 ? x[:,:]' : x
        ConformalPrediction.ConformalTraining.classification_loss(
            conf_model, fitresult, x, y;
            temp=temp, loss_matrix = loss_mat,
        )
    end
    loss = agg(loss)[1]
    return loss
end