# penalties.jl
using ChainRules: ignore_derivatives
using LinearAlgebra: norm
using Statistics: mean

"""
    set_size_penalty(
        counterfactual_explanation::AbstractCounterfactualExplanation;
        κ::Real=0.0, temp::Real=0.05, agg=mean
    )

Penalty for smooth conformal set size.

The decoded counterfactual state is split into slices along its third dimension. For
each slice whose first target probability is at least `0.5`, the penalty is the first
element of `ConformalPrediction.smooth_size_loss` evaluated on the transposed slice;
for every other slice it is `0.0`. The per-slice penalties are then reduced with `agg`.

# Keywords
- `κ`: forwarded unchanged to `ConformalPrediction.smooth_size_loss`.
- `temp`: forwarded unchanged to `ConformalPrediction.smooth_size_loss`.
- `agg`: reduction applied to the collection of per-slice penalties (default `mean`).
"""
function set_size_penalty(
    counterfactual_explanation::AbstractCounterfactualExplanation;
    κ::Real=0.0, temp::Real=0.05, agg=mean
)
    wrapped = counterfactual_explanation.M
    decoded = CounterfactualExplanations.decode_state(counterfactual_explanation)

    slice_penalties = map(eachslice(decoded, dims=3)) do slice
        xmat = Matrix(slice)
        # Only slices already predicted as the target (probability >= 0.5) are
        # penalised; the rest contribute a constant zero.
        if target_probs(counterfactual_explanation, xmat)[1] >= 0.5
            ConformalPrediction.smooth_size_loss(
                wrapped.model, wrapped.fitresult, xmat';
                κ=κ,
                temp=temp
            )[1]
        else
            0.0
        end
    end

    return agg(slice_penalties)
end

"""
    distance_from_energy(
        counterfactual_explanation::AbstractCounterfactualExplanation;
        n::Int=100, from_buffer=true, agg=mean, kwargs...
    )

Average Euclidean distance between each counterfactual and samples drawn from an
energy sampler.

The sampler is looked up under the key `:energy_sampler` in
`counterfactual_explanation.params`. If the key is absent, a `CCE.EnergySampler` is
constructed from the explanation (with `kwargs...`) and stored there, so later calls
reuse it. Sampling happens inside `ignore_derivatives`, so it is treated as a constant
by automatic differentiation.

For every slice of the counterfactual along its third dimension, the norm of the
difference to each column of the drawn samples is computed and averaged with `mean`;
the per-slice averages are reduced with `agg`, and the first element of that result is
returned.

# Keywords
- `n`: number of samples requested from the sampler.
- `from_buffer`: forwarded to `rand(sampler, n; from_buffer=...)`.
- `agg`: reduction applied to the per-slice mean distances (default `mean`).
- `kwargs...`: forwarded to `CCE.EnergySampler` when a sampler has to be created.
"""
function distance_from_energy(
    counterfactual_explanation::AbstractCounterfactualExplanation;
    n::Int=100, from_buffer=true, agg=mean, kwargs...
)
    # `ignore_derivatives` returns the value of its closure, so the samples can be
    # bound directly instead of being smuggled out through a mutable container.
    conditional_samples = ignore_derivatives() do
        params = counterfactual_explanation.params
        if !haskey(params, :energy_sampler)
            params[:energy_sampler] = CCE.EnergySampler(
                counterfactual_explanation; kwargs...
            )
        end
        return rand(params[:energy_sampler], n; from_buffer=from_buffer)
    end

    x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
    loss = map(eachslice(x′, dims=3)) do x
        x = Matrix(x)
        # Each column of `conditional_samples` is treated as one sample.
        Δ = map(eachcol(conditional_samples)) do xsample
            norm(x - xsample)
        end
        return mean(Δ)
    end

    return agg(loss)[1]
end

"""
    distance_from_targets(
        counterfactual_explanation::AbstractCounterfactualExplanation;
        n::Int=100, agg=mean
    )

Average Euclidean distance between each counterfactual and `n` randomly drawn columns
of `counterfactual_explanation.data.X`.

Column indices are drawn with `rand`, i.e. with replacement, from all columns of `X`.
NOTE(review): despite the function name, no filtering by target class is visible here;
confirm whether `data.X` is expected to hold only target-class observations.

The draw is wrapped in `ignore_derivatives`, matching `distance_from_energy`, so the
sampled reference points are treated as constants by automatic differentiation.

For every slice of the counterfactual along its third dimension, the norm of the
difference to each sampled column is computed and averaged with `mean`; the per-slice
averages are reduced with `agg`, and the first element of that result is returned.

# Keywords
- `n`: number of columns to draw from `X`.
- `agg`: reduction applied to the per-slice mean distances (default `mean`).
"""
function distance_from_targets(
    counterfactual_explanation::AbstractCounterfactualExplanation;
    n::Int=100, agg=mean
)
    # The reference points do not depend on the counterfactual, so keep the random
    # indexing into the data matrix out of the differentiated computation.
    target_samples = ignore_derivatives() do
        X = counterfactual_explanation.data.X
        return X[:, rand(1:end, n)]
    end

    x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
    loss = map(eachslice(x′, dims=3)) do x
        x = Matrix(x)
        Δ = map(eachcol(target_samples)) do xsample
            norm(x - xsample)
        end
        return mean(Δ)
    end

    return agg(loss)[1]
end