Skip to content
Snippets Groups Projects
generator.jl 1.28 KiB
Newer Older
Pat Alt's avatar
Pat Alt committed
using CounterfactualExplanations.Objectives
using CounterfactualExplanations.Generators: GradientBasedGenerator
Pat Alt's avatar
Pat Alt committed

pat-alt's avatar
pat-alt committed
"""
    ECCCoGenerator(; λ=[0.2, 0.4, 0.4], κ=1.0, temp=0.1, opt=nothing,
                   use_class_loss=false, nsamples=50, nmin=25,
                   use_energy_delta=false, kwargs...)

Construct an `ECCCoGenerator`: an Energy Constrained Conformal Counterfactual
Explanation generator. It is returned as a `GradientBasedGenerator` with three
penalties: L1 distance, the conformal set-size penalty, and an energy-based penalty.

# Keywords
- `λ`: penalty weights, in the order `[distance_l1, set_size_penalty, energy]`.
  A scalar `λ` is expanded to `[0.0, λ, λ]`, i.e. the L1 distance penalty gets
  weight zero and the two other penalties both get weight `λ`.
- `κ`: forwarded to `ECCCo.set_size_penalty`.
- `temp`: forwarded to `ECCCo.set_size_penalty` and, when `use_class_loss=true`,
  to `conformal_training_loss`.
- `opt`: optimiser; `nothing` is replaced by
  `CounterfactualExplanations.Generators.Descent(0.1)`.
- `use_class_loss`: if `true`, `conformal_training_loss(ce; temp=temp)` is used as
  the loss; otherwise `loss=nothing` is passed to `GradientBasedGenerator`
  (presumably selecting its default loss — confirm in CounterfactualExplanations).
- `nsamples`, `nmin`: forwarded as `n` and `nmin` to the energy penalty.
- `use_energy_delta`: if `true`, the energy penalty is `ECCCo.energy_delta`,
  otherwise `ECCCo.distance_from_energy`.
- `kwargs...`: forwarded unchanged to `GradientBasedGenerator`.
"""
function ECCCoGenerator(;
    λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.2, 0.4, 0.4],
    κ::Real=1.0,
    temp::Real=0.1,
    opt::Union{Nothing,Flux.Optimise.AbstractOptimiser}=nothing,
    use_class_loss::Bool=false,
    nsamples::Int=50,
    nmin::Int=25,
    use_energy_delta::Bool=false,
    kwargs...
)
    # Default optimiser
    if isnothing(opt)
        opt = CounterfactualExplanations.Generators.Descent(0.1)
    end

    # Loss function. An anonymous closure is used rather than a local method
    # defined inside one branch: Julia hoists local method definitions, so
    # mixing `loss_fun(ce) = ...` in one branch with `loss_fun = nothing` in the
    # other is fragile.
    loss_fun = if use_class_loss
        (ce::AbstractCounterfactualExplanation) ->
            conformal_training_loss(ce; temp=temp)
    else
        nothing
    end

    # Energy penalty: both variants take the same sampling parameters.
    energy_fun = use_energy_delta ? ECCCo.energy_delta : ECCCo.distance_from_energy
    _energy_penalty = (energy_fun, (n=nsamples, nmin=nmin))

    # Penalties, in the order the entries of `λ` refer to.
    _penalties = [
        (Objectives.distance_l1, []),
        (ECCCo.set_size_penalty, (κ=κ, temp=temp)),
        _energy_penalty,
    ]

    # A scalar weight switches off the L1 distance penalty and applies the same
    # weight to the set-size and energy penalties.
    λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ

    # Generator
    return GradientBasedGenerator(;
        loss=loss_fun, penalty=_penalties, λ=λ, opt=opt, kwargs...
    )
end