# penalties.jl
using ChainRules: ignore_derivatives
using CounterfactualExplanations: get_target_index
pat-alt's avatar
pat-alt committed
using Distances
pat-alt's avatar
pat-alt committed
using LinearAlgebra: norm
pat-alt's avatar
pat-alt committed
using Statistics: mean

"""
    set_size_penalty(ce::AbstractCounterfactualExplanation; κ=1.0, temp=0.1, agg=mean)

Penalty for smooth conformal set size.

The current counterfactual state of `ce` is decoded and split into slices along its last
dimension. Each slice contributes one loss value and the values are combined with `agg`.

# Keywords
- `κ::Real=1.0`: forwarded to `ConformalTraining.smooth_size_loss`.
- `temp::Real=0.1`: forwarded to `ConformalTraining.smooth_size_loss`.
- `agg=mean`: function applied to the collection of per-slice losses.

# Returns
The result of `agg` applied to the per-slice losses.
"""
function set_size_penalty(
    ce::AbstractCounterfactualExplanation;
    κ::Real=1.0, temp::Real=0.1, agg=mean
)
    conf_model = ce.M.model
    fitresult = ce.M.fitresult
    X = CounterfactualExplanations.decode_state(ce)

    slice_losses = map(eachslice(X, dims=ndims(X))) do x
        # A vector slice is promoted to a one-column matrix before it is scored.
        x = ndims(x) == 1 ? x[:, :] : x
        if target_probs(ce, x)[1] >= 0.5
            # The set-size loss only applies once the first target probability reaches 0.5;
            # the slice is passed transposed, as in the original call.
            return ConformalPrediction.ConformalTraining.smooth_size_loss(
                conf_model, fitresult, x';
                κ=κ,
                temp=temp
            )[1]
        else
            return 0.0
        end
    end

    return agg(slice_losses)
end

Pat Alt's avatar
Pat Alt committed
"""
    energy_delta(ce::AbstractCounterfactualExplanation; kwargs...)

Energy-based penalty comparing the current counterfactual state of `ce` with samples
produced by `ECCCo.EnergySampler`.

The energy `E(x)` is the negative logit of the target class. The returned value is
`mean(E(proposed) .- E(generated)) + reg_strength * mean(E(generated).^2 .+ E(proposed).^2)`.

# Keywords
- `niter=500`: forwarded to `ECCCo.EnergySampler`.
- `return_conditionals=false`: if `true`, return the sampled object instead of the loss.
- `reg_strength=0.1`: weight of the regularization term.
- `n`, `from_buffer`, `agg`, `choose_lowest_energy`, `choose_random`, `nmin`: accepted for
  interface compatibility; the current implementation does not read them.
- `kwargs...`: forwarded to `ECCCo.EnergySampler`.

# Returns
The scalar loss, or the result of the sampler call when `return_conditionals` is `true`.
"""
function energy_delta(
    ce::AbstractCounterfactualExplanation;
    n::Int=50, niter=500, from_buffer=true, agg=mean,
    choose_lowest_energy=true,
    choose_random=false,
    nmin::Int=25,
    return_conditionals=false,
    reg_strength=0.1,
    kwargs...
)
    # Sampling is excluded from differentiation; `ignore_derivatives` returns the value
    # of the closure, so no intermediate container is needed.
    # NOTE(review): the result of `ECCCo.EnergySampler(...)` is passed straight to
    # `logits` below — confirm that it is array-like.
    xgenerated = ignore_derivatives() do
        ECCCo.EnergySampler(ce; niter=niter, nsamples=ce.num_counterfactuals, kwargs...)
    end

    # Nothing else is needed when only the conditional samples are requested.
    return_conditionals && return xgenerated

    xproposed = CounterfactualExplanations.decode_state(ce)     # current state
    t = get_target_index(ce.data.y_levels, ce.target)
    E(x) = -logits(ce.M, x)[t, :]                               # negative target logits

    # Evaluate each energy once and reuse it in both loss terms.
    e_proposed = E(xproposed)
    e_generated = E(xgenerated)

    # Generative loss, averaged over samples:
    gen_terms = e_proposed .- e_generated
    gen_loss = sum(gen_terms) / length(gen_terms)

    # Regularization loss, averaged over samples:
    reg_terms = e_generated .^ 2 .+ e_proposed .^ 2
    reg_loss = sum(reg_terms) / length(reg_terms)

    return gen_loss + reg_strength * reg_loss
end

pat-alt's avatar
pat-alt committed
"""
    distance_from_energy(ce::AbstractCounterfactualExplanation; kwargs...)

Average `distance_l1` between the counterfactual `ce` and samples drawn from an
`ECCCo.EnergySampler`.

The sampler is created on first use and cached in `ce.params[:energy_sampler]`; later
calls reuse the cached sampler.

# Keywords
- `n::Int=50`: number of samples requested from the sampler.
- `niter=500`: forwarded to `ECCCo.EnergySampler` when the sampler is created.
- `from_buffer=true`: forwarded to `rand` when `choose_random` is `true`.
- `agg=mean`: forwarded to `distance_l1`.
- `choose_lowest_energy=true`: use the result of `ECCCo.get_lowest_energy_sample`.
- `choose_random=false`: use `rand(sampler, n; from_buffer)` instead.
- `nmin::Int=25`: upper bound on the number of lowest-energy samples; it is further
  capped by `n` and by the size of the last dimension of the sampler buffer.
- `return_conditionals=false`: if `true`, return the selected samples instead of the loss.
- `kwargs...`: forwarded to `ECCCo.EnergySampler` when the sampler is created.

If neither `choose_lowest_energy` nor `choose_random` is set, the whole sampler buffer is
used. Setting both triggers an assertion failure.

# Returns
The mean of the per-sample distances, or the samples when `return_conditionals` is `true`.
"""
function distance_from_energy(
    ce::AbstractCounterfactualExplanation;
    n::Int=50, niter=500, from_buffer=true, agg=mean,
    choose_lowest_energy=true,
    choose_random=false,
    nmin::Int=25,
    return_conditionals=false,
    kwargs...
)
    # Equivalent to "exactly one of the two, or neither".
    @assert !(choose_lowest_energy && choose_random) "Must choose either lowest energy or random samples or neither."

    # Sampling is excluded from differentiation; the closure's value is the sample set.
    samples = ignore_derivatives() do
        params = ce.params
        if !haskey(params, :energy_sampler)
            params[:energy_sampler] = ECCCo.EnergySampler(
                ce; niter=niter, nsamples=n, kwargs...
            )
        end
        sampler = params[:energy_sampler]
        if choose_lowest_energy
            # A fresh local avoids reassigning the captured keyword inside the closure.
            k = min(nmin, n, size(sampler.buffer)[end])
            ECCCo.get_lowest_energy_sample(sampler; n=k)
        elseif choose_random
            rand(sampler, n; from_buffer=from_buffer)
        else
            sampler.buffer
        end
    end

    # Nothing else is needed when only the conditional samples are requested.
    return_conditionals && return samples

    distances = map(eachcol(samples)) do xsample
        distance_l1(ce; from=xsample, agg=agg)
    end

    # Average over the samples actually used; their count can differ from `n`.
    return sum(distances) / length(distances)
end

function distance_from_targets(
Pat Alt's avatar
Pat Alt committed
    ce::AbstractCounterfactualExplanation;
Pat Alt's avatar
Pat Alt committed
    n::Int=1000, agg=mean,
    n_nearest_neighbors::Union{Int,Nothing}=nothing,
Pat Alt's avatar
Pat Alt committed
    target_idx = ce.data.output_encoder.labels .== ce.target
    target_samples = ce.data.X[:,target_idx] |>
        X -> X[:,rand(1:end,n)]
Pat Alt's avatar
Pat Alt committed
    x′ = CounterfactualExplanations.counterfactual(ce)
Pat Alt's avatar
Pat Alt committed
    loss = map(eachslice(x′, dims=ndims(x′))) do x
        Δ = map(eachcol(target_samples)) do xsample
Pat Alt's avatar
Pat Alt committed
            norm(x - xsample, 1)
Pat Alt's avatar
Pat Alt committed
        if n_nearest_neighbors != nothing
            Δ = sort(Δ)[1:n_nearest_neighbors]
        end
pat-alt's avatar
pat-alt committed
    loss = agg(loss)[1]