Skip to content
Snippets Groups Projects
generator.jl 1.94 KiB
Newer Older
Pat Alt's avatar
Pat Alt committed
using CounterfactualExplanations.Objectives

"""
    CCEGenerator(; λ=[0.1, 1.0], κ=1.0, temp=0.5, kwargs...)

Constructor for `CCEGenerator`.

Builds a `Generator` with two penalties, in this order:
1. `Objectives.distance_l2`;
2. `ECCCo.set_size_penalty`, called with the `κ` and `temp` given here.

# Keywords
- `λ`: one weight per penalty, in the order above. A scalar `λ` is expanded to
  `[0.0, λ]`, so only the set-size penalty is weighted and the distance penalty gets
  weight zero. A vector `λ` is passed through unchanged.
- `κ`, `temp`: forwarded to `ECCCo.set_size_penalty`.
- `kwargs...`: forwarded unchanged to `Generator`.
"""
function CCEGenerator(;
    λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0],
    κ::Real=1.0,
    temp::Real=0.5,
    kwargs...,
)
    # Close over κ and temp so the penalty can be called with the explanation alone.
    set_size_term(ce::AbstractCounterfactualExplanation) =
        ECCCo.set_size_penalty(ce; κ=κ, temp=temp)
    penalties = [Objectives.distance_l2, set_size_term]
    weights = if λ isa AbstractFloat
        [0.0, λ]
    else
        λ
    end
    return Generator(; penalty=penalties, λ=weights, kwargs...)
end

pat-alt's avatar
pat-alt committed
"""
    ECCCoGenerator(; λ=[0.1, 1.0, 1.0], κ=1.0, temp=0.5, η=nothing, n=nothing,
                   opt=JSMADescent(η=η, n=n), kwargs...)

Constructor for `ECCCoGenerator`: Energy Constrained Conformal Counterfactual Explanation
Generator.

Builds a `Generator` with three penalties, in this order:
1. `Objectives.distance_l1`;
2. `ECCCo.set_size_penalty`, called with the `κ` and `temp` given here;
3. `ECCCo.distance_from_energy`.

# Keywords
- `λ`: one weight per penalty, in the order above. A scalar `λ` is expanded to
  `[0.0, λ, λ]`, so the distance penalty gets weight zero and the other two share `λ`.
  A vector `λ` is passed through unchanged.
- `κ`, `temp`: forwarded to `ECCCo.set_size_penalty`.
- `η`, `n`: used only to build the default `opt`
  (`CounterfactualExplanations.Generators.JSMADescent`). They have no effect when `opt`
  is supplied explicitly.
- `opt`: optimiser handed to `Generator`.
- `kwargs...`: forwarded unchanged to `Generator`.
"""
function ECCCoGenerator(;
    λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0, 1.0],
    κ::Real=1.0,
    temp::Real=0.5,
    η::Union{Nothing,Real}=nothing,
    n::Union{Nothing,Int}=nothing,
    opt::Flux.Optimise.AbstractOptimiser=CounterfactualExplanations.Generators.JSMADescent(η=η, n=n),
    kwargs...,
)
    # Close over κ and temp so the penalty can be called with the explanation alone.
    function _set_size_penalty(ce::AbstractCounterfactualExplanation)
        return ECCCo.set_size_penalty(ce; κ=κ, temp=temp)
    end
    _penalties = [Objectives.distance_l1, _set_size_penalty, ECCCo.distance_from_energy]
    # Scalar λ: switch off the distance term, weight set-size and energy terms equally.
    λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ
    return Generator(; penalty=_penalties, λ=λ, opt=opt, kwargs...)
end

Pat Alt's avatar
Pat Alt committed
"""
    EnergyDrivenGenerator(; λ=[0.1, 1.0], kwargs...)

Constructor for `EnergyDrivenGenerator`.

Builds a `Generator` with two penalties, in this order:
1. `Objectives.distance_l2`;
2. `ECCCo.distance_from_energy`.

A scalar `λ` is expanded to `[0.0, λ]`, so the distance penalty gets weight zero. A
vector `λ` is passed through unchanged. All other keywords go straight to `Generator`.
"""
function EnergyDrivenGenerator(;
    λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], kwargs...
)
    weights = if λ isa AbstractFloat
        [0.0, λ]
    else
        λ
    end
    penalties = [Objectives.distance_l2, ECCCo.distance_from_energy]
    return Generator(; penalty=penalties, λ=weights, kwargs...)
end

"""
    TargetDrivenGenerator(; λ=[0.1, 1.0], kwargs...)

Constructor for `TargetDrivenGenerator`.

Builds a `Generator` with two penalties, in this order:
1. `Objectives.distance_l2`;
2. `ECCCo.distance_from_targets`.

A scalar `λ` is expanded to `[0.0, λ]`, so the distance penalty gets weight zero. A
vector `λ` is passed through unchanged. All other keywords go straight to `Generator`.
"""
function TargetDrivenGenerator(;
    λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], kwargs...
)
    weights = if λ isa AbstractFloat
        [0.0, λ]
    else
        λ
    end
    penalties = [Objectives.distance_l2, ECCCo.distance_from_targets]
    return Generator(; penalty=penalties, λ=weights, kwargs...)
end