Skip to content
Snippets Groups Projects
penalties.jl 2.15 KiB
Newer Older
pat-alt's avatar
pat-alt committed
using ChainRules: ignore_derivatives
pat-alt's avatar
pat-alt committed
using Distances
pat-alt's avatar
pat-alt committed
using LinearAlgebra: norm
pat-alt's avatar
pat-alt committed
using Statistics: mean

"""
    set_size_penalty(
        ce::AbstractCounterfactualExplanation;
        κ::Real=0.0, temp::Real=0.05, agg=mean
    )

Penalty for smooth conformal set size.

The decoded counterfactual state of `ce` is split into slices along its last
dimension. For every slice whose first entry of `target_probs(ce, x)` is at least
`0.5`, the first element of `ConformalPrediction.smooth_size_loss` is used as that
slice's loss. All other slices contribute `0.0`. The per-slice losses are then
reduced with `agg`.

# Keywords
- `κ::Real=0.0`: forwarded to `ConformalPrediction.smooth_size_loss`.
- `temp::Real=0.05`: forwarded to `ConformalPrediction.smooth_size_loss`.
- `agg=mean`: reduction applied to the collection of per-slice losses.
"""
function set_size_penalty(
    ce::AbstractCounterfactualExplanation;
    κ::Real=0.0, temp::Real=0.05, agg=mean
)
    model = ce.M.model
    fit = ce.M.fitresult
    state = CounterfactualExplanations.decode_state(ce)

    # Loss contributed by a single slice of the decoded state. A 1-D slice is
    # turned into an n×1 matrix (`slice[:, :]`) before use; the model input is its
    # transpose.
    function slice_loss(slice)
        input = ndims(slice) == 1 ? slice[:, :] : slice
        # Slices below the 0.5 target-probability threshold are not penalised.
        target_probs(ce, input)[1] >= 0.5 || return 0.0
        return ConformalPrediction.smooth_size_loss(
            model, fit, input'; κ=κ, temp=temp
        )[1]
    end

    per_slice = map(slice_loss, eachslice(state; dims=ndims(state)))
    return agg(per_slice)
end

"""
    distance_from_energy(
        ce::AbstractCounterfactualExplanation;
        n::Int=10, niter=60, from_buffer=true, agg=mean, kwargs...
    )

Penalty measuring the distance between the counterfactual(s) in `ce` and samples
drawn from an energy sampler.

An `ECCCo.EnergySampler` is built on first use and cached in `ce.params` under the
key `:energy_sampler`. Later calls reuse the cached sampler. `n` samples are then
drawn with `rand(sampler, n; from_buffer=from_buffer)`. Sampler construction and
sampling run inside `ignore_derivatives`, so they are excluded from
differentiation.

For each slice of the counterfactual along its last dimension, the mean L1
distance to the sampled columns is computed. These per-slice values are reduced
with `agg`.

# Keywords
- `n::Int=10`: number of samples drawn per call; also passed as `nsamples` when
  the sampler is first built.
- `niter=60`: passed to `ECCCo.EnergySampler` when the sampler is first built.
- `from_buffer=true`: forwarded to `rand(sampler, n; from_buffer=...)`.
- `agg=mean`: reduction applied to the per-slice mean distances.
- `kwargs...`: forwarded to `ECCCo.EnergySampler` when the sampler is first built.

Note that `niter` and `kwargs` only take effect on the call that creates the
cached sampler. Subsequent calls reuse it unchanged.
"""
function distance_from_energy(
    ce::AbstractCounterfactualExplanation;
    n::Int=10, niter=60, from_buffer=true, agg=mean, kwargs...
)
    # Sampling is not part of the differentiable penalty. The do-block's value is
    # returned directly, so no untyped container is needed to carry it out.
    conditional_samples = ignore_derivatives() do
        _dict = ce.params
        # Build the sampler lazily and cache it on the explanation for reuse.
        if !haskey(_dict, :energy_sampler)
            _dict[:energy_sampler] = ECCCo.EnergySampler(
                ce; niter=niter, nsamples=n, kwargs...
            )
        end
        sampler = _dict[:energy_sampler]
        return rand(sampler, n; from_buffer=from_buffer)
    end

    x′ = CounterfactualExplanations.counterfactual(ce)
    loss = map(eachslice(x′; dims=ndims(x′))) do x
        # Mean L1 distance from this slice to every sampled column.
        Δ = map(eachcol(conditional_samples)) do xsample
            norm(x - xsample, 1)
        end
        return mean(Δ)
    end

    return agg(loss)[1]
end

function distance_from_targets(
Pat Alt's avatar
Pat Alt committed
    ce::AbstractCounterfactualExplanation;
Pat Alt's avatar
Pat Alt committed
    n::Int=100, agg=mean
Pat Alt's avatar
Pat Alt committed
    target_idx = ce.data.output_encoder.labels .== ce.target
    target_samples = ce.data.X[:,target_idx] |>
        X -> X[:,rand(1:end,n)]
Pat Alt's avatar
Pat Alt committed
    x′ = CounterfactualExplanations.counterfactual(ce)
Pat Alt's avatar
Pat Alt committed
    loss = map(eachslice(x′, dims=ndims(x′))) do x
        Δ = map(eachcol(target_samples)) do xsample
Pat Alt's avatar
Pat Alt committed
            norm(x - xsample, 1)
pat-alt's avatar
pat-alt committed
    loss = agg(loss)[1]