# ConformalGenerator.jl
using CounterfactualExplanations
using CounterfactualExplanations.Generators
using Flux
using LinearAlgebra
using Parameters
using Statistics

"""
    ConformalGenerator

Gradient-based counterfactual generator (a subtype of `AbstractGradientBasedGenerator`).

Its complexity penalty (the `Generators.h` method for this type in this file) adds a
distance-from-factual term and a gradient-penalty term, weighted by `λ`.
"""
mutable struct ConformalGenerator <: AbstractGradientBasedGenerator
    loss::Union{Nothing,Symbol} # name of the loss function, or `nothing` if not set
    complexity::Function # distance function applied to (factual .- counterfactual)
    λ::Union{AbstractFloat,AbstractVector} # penalty strength: one weight, or λ[1] (distance) and λ[2] (gradient penalty)
    decision_threshold::Union{Nothing,AbstractFloat} # optional decision threshold; `nothing` if unset
    opt::Flux.Optimise.AbstractOptimiser # optimizer
    τ::AbstractFloat # tolerance for convergence
end

# API streamlining: keyword-constructible bundle of the optimiser settings. The
# `ConformalGenerator` keyword constructor forwards its extra `kwargs...` here, so
# callers can pass `opt` and `τ` by name and get these defaults otherwise.
@with_kw struct ConformalGeneratorParams
    opt::Flux.Optimise.AbstractOptimiser = Descent() # optimizer; defaults to plain gradient descent
    τ::AbstractFloat = 1e-3 # tolerance for convergence
end

"""
    ConformalGenerator(
        ;
        loss::Union{Nothing,Symbol}=nothing,
        complexity::Function=norm,
        λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0],
        decision_threshold=nothing,
        opt::Flux.Optimise.AbstractOptimiser=Descent(),
        τ::AbstractFloat=1e-3,
    )

Outer keyword constructor for `ConformalGenerator`.

# Keywords
- `loss`: name of the loss function, or `nothing` (the default).
- `complexity`: distance function applied to the difference between the factual and
  the counterfactual state; defaults to `norm`.
- `λ`: penalty strength. A single value applies one weight to both terms of the
  complexity penalty. A vector supplies separate weights: `λ[1]` for the distance
  and `λ[2]` for the gradient penalty. Defaults to `[0.1, 1.0]`.
- `decision_threshold`: optional decision threshold stored on the generator
  (`nothing` by default).
- Any remaining keywords (`opt`, `τ`) are forwarded to `ConformalGeneratorParams`,
  whose defaults are `Descent()` and `1e-3`.

# Throws
- `ArgumentError` if `λ` is an empty vector.

# Examples
```julia-repl
generator = ConformalGenerator()
```
"""
function ConformalGenerator(;
    loss::Union{Nothing,Symbol} = nothing,
    complexity::Function = norm,
    λ::Union{AbstractFloat,AbstractVector} = [0.1, 1.0],
    decision_threshold = nothing,
    kwargs...,
)
    # An empty weight vector would only fail later, with a BoundsError inside the penalty.
    if λ isa AbstractVector && isempty(λ)
        throw(ArgumentError("λ must be a scalar or a non-empty vector of weights"))
    end
    params = ConformalGeneratorParams(; kwargs...)
    return ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ)
end

# Loss:
# """
#     ℓ(generator::ConformalGenerator, counterfactual_explanation::AbstractCounterfactualExplanation)

# The default method to apply the generator loss function to the current counterfactual state for any generator.
# """
# function ℓ(
#     generator::ConformalGenerator,
#     counterfactual_explanation::AbstractCounterfactualExplanation,
# )

#     loss_fun =
#         !isnothing(generator.loss) ? getfield(Losses, generator.loss) :
#         CounterfactualExplanations.guess_loss(counterfactual_explanation)
#     @assert !isnothing(loss_fun) "No loss function provided and loss function could not be guessed based on model."
#     loss = loss_fun(
#         getfield(Models, :logits)(
#             counterfactual_explanation.M,
#             CounterfactualExplanations.decode_state(counterfactual_explanation),
#         ),
#         counterfactual_explanation.target_encoded,
#     )
#     return loss
# end

"""
    set_size_penalty(
        generator::ConformalGenerator,
        counterfactual_explanation::AbstractCounterfactualExplanation,
    )

Placeholder for an additional penalty specific to `ConformalGenerator` (presumably
based on the size of the conformal prediction set — TODO confirm).

Not implemented yet: this method performs no computation and always returns `nothing`.
"""
function set_size_penalty(
    generator::ConformalGenerator,
    counterfactual_explanation::AbstractCounterfactualExplanation,
)
    # TODO: implement the penalty; until then this is an explicit no-op.
    return nothing
end

# Complexity:
"""
    h(generator::ConformalGenerator, counterfactual_explanation::AbstractCounterfactualExplanation)

Complexity penalty of `generator` at the current counterfactual state.

The penalty combines two terms:

1. The distance `generator.complexity(x .- x′)` between the factual `x` and the decoded
   counterfactual state `x′`.
2. A gradient penalty computed by `gradient_penalty`. It applies only once every target
   probability has reached the decision threshold: `generator.decision_threshold`, or
   `0.5` when that field is `nothing`. Otherwise this term is zero.

If `generator.λ` holds a single value (a scalar or a one-element vector), that value
scales the sum of both terms. Otherwise `λ[1]` scales the distance and `λ[2]` scales
the gradient penalty.
"""
function Generators.h(
    generator::ConformalGenerator,
    counterfactual_explanation::AbstractCounterfactualExplanation,
)

    # Distance from factual:
    dist_ = generator.complexity(
        counterfactual_explanation.x .-
        CounterfactualExplanations.decode_state(counterfactual_explanation),
    )

    # Gradient penalty, applied only once the counterfactual is in the target domain.
    # Honour the generator's own decision threshold; 0.5 is the fallback when unset.
    threshold = something(generator.decision_threshold, 0.5)
    in_target_domain = all(target_probs(counterfactual_explanation) .>= threshold)
    if in_target_domain
        grad_norm = gradient_penalty(generator, counterfactual_explanation)
    else
        # Same type as the distance term, rather than an `Int` literal.
        grad_norm = zero(dist_)
    end

    λ = generator.λ
    if length(λ) == 1
        # `first` unwraps both a scalar and a one-element vector, so the penalty stays
        # a scalar (a bare `λ * …` would return a one-element Vector for vector λ).
        penalty = first(λ) * (dist_ .+ grad_norm)
    else
        penalty = λ[1] * dist_ .+ λ[2] * grad_norm
    end
    return penalty
end