Skip to content
Snippets Groups Projects
generator.jl 1.95 KiB
using CounterfactualExplanations.Objectives
using CounterfactualExplanations.Generators: GradientBasedGenerator

"""
    ECCCoGenerator(; λ=[0.2, 0.4, 0.4], κ=1.0, temp=0.1, opt=nothing,
                   use_class_loss=false, use_energy_delta=false, nsamples=nothing,
                   nmin=nothing, niter=nothing, reg_strength=0.1, decay=(0.1, 1),
                   dim_reduction=false, kwargs...)

Constructor for `ECCCoGenerator`: Energy Constrained Conformal Counterfactual Explanation
Generator. Returns a `GradientBasedGenerator` configured with three penalties, in this
order: `Objectives.distance_l1`, `ECCCo.set_size_penalty` and an energy-based penalty.

# Keywords
- `λ`: penalty weights, one per penalty (three in total). A scalar `λ` expands to
  `[0, λ, λ]`, i.e. it switches off the L1 distance penalty and weights the other two
  penalties by `λ`.
- `κ`, `temp`: forwarded to `ECCCo.set_size_penalty`. `temp` is also forwarded to
  `conformal_training_loss` when `use_class_loss` is `true`.
- `opt`: optimiser. Defaults to `Descent(0.1)` when `nothing`.
- `use_class_loss`: if `true`, `conformal_training_loss` is used as the loss; otherwise
  the generator is built with `loss = nothing`.
- `use_energy_delta`: if `true`, the energy penalty is `ECCCo.energy_delta` (which also
  receives `reg_strength` and `decay`); otherwise it is `ECCCo.distance_from_energy`.
- `nsamples`, `nmin`, `niter`: forwarded to the energy penalty as `n`, `nmin` and
  `niter`. They default to 10, 1 and 100 respectively when `nothing`.
- `reg_strength`, `decay`: only used when `use_energy_delta` is `true`.
- `dim_reduction`: forwarded to `GradientBasedGenerator`.
- `kwargs...`: any further keywords are forwarded to `GradientBasedGenerator`.

# Throws
- `ArgumentError` if a vector `λ` does not contain exactly one weight per penalty.
"""
function ECCCoGenerator(;
    λ::Union{AbstractFloat,Vector{<:AbstractFloat}} = [0.2, 0.4, 0.4],
    κ::Real = 1.0,
    temp::Real = 0.1,
    opt::Union{Nothing,Flux.Optimise.AbstractOptimiser} = nothing,
    use_class_loss::Bool = false,
    use_energy_delta::Bool = false,
    nsamples::Union{Nothing,Int} = nothing,
    nmin::Union{Nothing,Int} = nothing,
    niter::Union{Nothing,Int} = nothing,
    reg_strength::Real = 0.1,
    decay::Tuple = (0.1, 1),
    dim_reduction::Bool = false,
    kwargs...,
)

    # Default optimiser.
    opt = isnothing(opt) ? CounterfactualExplanations.Generators.Descent(0.1) : opt

    # An anonymous closure (rather than a local method defined inside an `if` branch,
    # which Julia hoists) keeps `loss_fun` a plain variable: a function or `nothing`.
    loss_fun =
        use_class_loss ?
        ((ce::AbstractCounterfactualExplanation) -> conformal_training_loss(ce; temp = temp)) :
        nothing

    # Keyword arguments shared by both energy penalties; `something` substitutes the
    # ECCCo defaults for any argument left as `nothing`.
    energy_kw = (
        n = something(nsamples, 10),
        nmin = something(nmin, 1),
        niter = something(niter, 100),
    )
    energy_penalty = if use_energy_delta
        (
            ECCCo.energy_delta,
            merge(energy_kw, (reg_strength = reg_strength, decay = decay)),
        )
    else
        (ECCCo.distance_from_energy, energy_kw)
    end

    penalties = [
        (Objectives.distance_l1, []),
        (ECCCo.set_size_penalty, (κ = κ, temp = temp)),
        energy_penalty,
    ]

    # A scalar λ switches off the L1 distance penalty and weights the other two by λ;
    # `zero(λ)` keeps the weight vector in λ's float type.
    weights = λ isa AbstractFloat ? [zero(λ), λ, λ] : λ
    length(weights) == length(penalties) || throw(
        ArgumentError(
            "expected $(length(penalties)) penalty weights in `λ`, " *
            "got $(length(weights))",
        ),
    )

    # Generator
    return GradientBasedGenerator(;
        loss = loss_fun,
        penalty = penalties,
        λ = weights,
        opt = opt,
        dim_reduction = dim_reduction,
        kwargs...,
    )
end