Newer
Older
Pat Alt
committed
using CounterfactualExplanations.Generators: GradientBasedGenerator
"Constructor for `ECECCCoGenerator`: Energy Constrained Conformal Counterfactual Explanation Generator."
temp::Real=0.1,
opt::Union{Nothing,Flux.Optimise.AbstractOptimiser}=nothing,
if isnothing(opt)
opt = CounterfactualExplanations.Generators.Descent(0.1)
end
# Loss function
if use_class_loss
loss_fun(ce::AbstractCounterfactualExplanation) = conformal_training_loss(ce; temp=temp)
else
loss_fun = nothing
end
_energy_penalty =
use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin))
_penalties = [
(Objectives.distance_l1, []),
(ECCCo.set_size_penalty, (κ=κ, temp=temp)),
_energy_penalty,
]
Pat Alt
committed
return GradientBasedGenerator(; loss=loss_fun, penalty=_penalties, λ=λ, opt=opt, kwargs...)