# ConformalGenerator.jl (3.61 KiB)
using CounterfactualExplanations
using CounterfactualExplanations.Generators
using Flux
using LinearAlgebra
using Parameters
using Statistics
"""
    ConformalGenerator

Gradient-based counterfactual generator (a subtype of `AbstractGradientBasedGenerator`).
Use the keyword constructor `ConformalGenerator(; kwargs...)` to build instances with
default settings.

# Fields
- `loss`: name of the loss function, or `nothing` (the keyword-constructor default).
- `complexity`: function that `Generators.h` applies to the difference between the
  factual `x` and the decoded counterfactual state to obtain the distance penalty.
- `λ`: penalty strength used by `Generators.h`. It is either a single weight or a
  vector whose first two entries weight the distance and gradient-penalty terms.
- `decision_threshold`: optional threshold, `nothing` by default. It is not read by any
  code visible in this file (NOTE(review): presumably consumed elsewhere — confirm).
- `opt`: Flux optimiser.
- `τ`: tolerance for convergence.
"""
mutable struct ConformalGenerator <: AbstractGradientBasedGenerator
loss::Union{Nothing,Symbol} # loss function name; `nothing` if not specified
complexity::Function # complexity function (distance penalty)
λ::Union{AbstractFloat,AbstractVector} # strength of penalty (single weight or per-term weights)
decision_threshold::Union{Nothing,AbstractFloat} # optional decision threshold
opt::Flux.Optimise.AbstractOptimiser # optimizer
τ::AbstractFloat # tolerance for convergence
end
# API streamlining:
# Optimiser settings that the `ConformalGenerator` keyword constructor receives through
# `kwargs...`. `@with_kw` (Parameters.jl) generates a keyword constructor using the
# defaults below.
@with_kw struct ConformalGeneratorParams
opt::Flux.Optimise.AbstractOptimiser = Descent() # optimiser; defaults to plain gradient descent
τ::AbstractFloat = 1e-3 # convergence tolerance
end
"""
    ConformalGenerator(;
        loss::Union{Nothing,Symbol}=nothing,
        complexity::Function=norm,
        λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0],
        decision_threshold=nothing,
        opt::Flux.Optimise.AbstractOptimiser=Descent(),
        τ::AbstractFloat=1e-3,
    )

Outer keyword constructor for `ConformalGenerator`.

# Keywords
- `loss`: name of the loss function; `nothing` by default.
- `complexity`: complexity (distance) function; defaults to `LinearAlgebra.norm`.
- `λ`: penalty strength. A single weight (when `length(λ) == 1`) scales the whole
  penalty. Otherwise `λ[1]` weights the distance term and `λ[2]` the gradient-penalty
  term. Defaults to `[0.1, 1.0]`.
- `decision_threshold`: stored as-is on the generator; `nothing` by default.
- `opt`, `τ`: collected by `kwargs...` and forwarded to `ConformalGeneratorParams`.
  That constructor supplies the defaults `Descent()` and `1e-3` and does not accept
  other keywords.

# Examples
```julia-repl
generator = ConformalGenerator()
```
"""
function ConformalGenerator(;
    loss::Union{Nothing,Symbol}=nothing,
    complexity::Function=norm,
    λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0],
    decision_threshold=nothing,
    kwargs...,
)
    # Only the optimiser-related settings travel through `kwargs...`; resolve their defaults.
    params = ConformalGeneratorParams(; kwargs...)
    return ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ)
end
# Loss:
# """
# ℓ(generator::ConformalGenerator, counterfactual_explanation::AbstractCounterfactualExplanation)
# The default method to apply the generator loss function to the current counterfactual state for any generator.
# """
# function ℓ(
# generator::ConformalGenerator,
# counterfactual_explanation::AbstractCounterfactualExplanation,
# )
# loss_fun =
# !isnothing(generator.loss) ? getfield(Losses, generator.loss) :
# CounterfactualExplanations.guess_loss(counterfactual_explanation)
# @assert !isnothing(loss_fun) "No loss function provided and loss function could not be guessed based on model."
# loss = loss_fun(
# getfield(Models, :logits)(
# counterfactual_explanation.M,
# CounterfactualExplanations.decode_state(counterfactual_explanation),
# ),
# counterfactual_explanation.target_encoded,
# )
# return loss
# end
"""
    set_size_penalty(
        generator::ConformalGenerator,
        counterfactual_explanation::AbstractCounterfactualExplanation,
    )

Placeholder for an additional, `ConformalGenerator`-specific penalty term.
It is not implemented yet: it performs no computation and returns `nothing`.
"""
set_size_penalty(
    generator::ConformalGenerator,
    counterfactual_explanation::AbstractCounterfactualExplanation,
) = nothing
# Complexity:
"""
    h(generator::ConformalGenerator, counterfactual_explanation::AbstractCounterfactualExplanation)

Complexity penalty of a `ConformalGenerator` for the current counterfactual state.

The penalty combines two terms:
1. `generator.complexity(x .- x′)`, the distance between the factual `x` and the decoded
   counterfactual state `x′`;
2. `gradient_penalty(generator, counterfactual_explanation)`. This term is only
   included once every entry of `target_probs(counterfactual_explanation)` is at least
   `0.5`; otherwise it is zero.

If `length(generator.λ) == 1`, that single weight scales the sum of both terms. This
yields a scalar weight even when `λ` is a one-element vector. Otherwise `λ[1]` weights
the distance and `λ[2]` the gradient penalty.
"""
function Generators.h(
    generator::ConformalGenerator,
    counterfactual_explanation::AbstractCounterfactualExplanation,
)
    # Distance from factual:
    dist_ = generator.complexity(
        counterfactual_explanation.x .-
        CounterfactualExplanations.decode_state(counterfactual_explanation),
    )

    # Gradient penalty, only applied once the counterfactual is in the target domain.
    # NOTE(review): the 0.5 threshold is hard-coded; `generator.decision_threshold` is
    # not consulted here — confirm this is intended.
    in_target_domain = all(target_probs(counterfactual_explanation) .>= 0.5)
    if in_target_domain
        grad_norm = gradient_penalty(generator, counterfactual_explanation)
    else
        grad_norm = 0
    end

    λ = generator.λ
    if length(λ) == 1
        # Index with [1] so a one-element vector yields a scalar weight (for a scalar λ,
        # `λ[1] == λ`). Multiplying by the vector itself would return a 1-element vector,
        # which cannot be added to a scalar loss or differentiated as a scalar objective.
        penalty = λ[1] * (dist_ .+ grad_norm)
    else
        penalty = λ[1] * dist_ .+ λ[2] * grad_norm
    end
    return penalty
end