# penalties.jl 9.37 KiB
using ChainRules: ignore_derivatives
using CounterfactualExplanations: get_target_index
using Distances
using Flux
using Images: assess_ssim
using LinearAlgebra: norm
using Statistics: mean
"""
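    set_size_penalty(ce::AbstractCounterfactualExplanation; κ=1.0, temp=0.1, agg=mean)

Penalty for the smooth conformal prediction set size of the counterfactual. The smooth
set size loss is evaluated only for slices of the decoded counterfactual state whose first
target probability is at least 0.5; every other slice contributes `0.0`. The per-slice
values are combined with `agg`.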
"""
function set_size_penalty(
    ce::AbstractCounterfactualExplanation;
    κ::Real=1.0, temp::Real=0.1, agg=mean
)
    # Conformal model and its fit result are read from the wrapped model `ce.M`.
    model = ce.M.model
    fit = ce.M.fitresult
    state = CounterfactualExplanations.decode_state(ce)

    # One penalty value per slice along the last dimension of the decoded state.
    penalties = map(eachslice(state, dims=ndims(state))) do slice
        # A 1-d slice is turned into a one-column matrix before being used.
        xmat = ndims(slice) == 1 ? slice[:, :] : slice
        # Slices whose first target probability is not at least 0.5 contribute nothing.
        if !(target_probs(ce, xmat)[1] >= 0.5)
            return 0.0
        end
        # The smooth size loss receives the adjoint of the slice and returns a
        # collection of which only the first entry is used.
        return ConformalPrediction.ConformalTraining.smooth_size_loss(
            model, fit, xmat';
            κ=κ,
            temp=temp
        )[1]
    end

    return agg(penalties)
end
"""
    energy_delta(ce::AbstractCounterfactualExplanation; reg_strength=0.1, kwargs...)

Energy-based penalty for the current counterfactual state.

With `E` the negative target-class logits of the decoded counterfactual state, returns
`sum(E) / length(E) + reg_strength * norm(E)^2`.

# Keywords
- `reg_strength`: weight of the squared-norm regularization term.
- `n`, `niter`, `from_buffer`, `agg`, `choose_lowest_energy`, `choose_random`, `nmin`,
  `return_conditionals`, `kwargs...`: accepted but not used by this function.
"""
function energy_delta(
    ce::AbstractCounterfactualExplanation;
    n::Int=50, niter=500, from_buffer=true, agg=mean,
    choose_lowest_energy=true,
    choose_random=false,
    nmin::Int=25,
    return_conditionals=false,
    reg_strength=0.1,
    kwargs...
)
    xproposed = CounterfactualExplanations.decode_state(ce)     # current state
    t = get_target_index(ce.data.y_levels, ce.target)

    # Negative logits of the target class. Evaluated once and shared by both loss terms
    # so the model is not run twice on the same input.
    energies = -logits(ce.M, xproposed)[t, :]

    # Generative loss: average energy.
    gen_loss = reduce((x, y) -> x + y, energies) / length(energies)

    # Regularization loss: squared Euclidean norm of the energies (already a scalar,
    # so no further aggregation is needed).
    reg_loss = norm(energies)^2

    return gen_loss + reg_strength * reg_loss
end
"""
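    distance_from_energy(ce::AbstractCounterfactualExplanation; kwargs...)

Computes the distance from the counterfactual to conditional samples generated by the
energy sampler stored under `ce.params[:energy_sampler]` (created on first use).
Depending on `choose_lowest_energy` and `choose_random`, the lowest-energy samples, random
draws, or the sampler's whole buffer are used. With `return_conditionals=true` the
selected samples are returned instead of the loss.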
"""
function distance_from_energy(
    ce::AbstractCounterfactualExplanation;
    n::Int=50, niter=500, from_buffer=true, agg=mean,
    choose_lowest_energy=true,
    choose_random=false,
    nmin::Int=25,
    return_conditionals=false,
    p::Int=1,
    kwargs...
)
    # Keywords:
    #   n, niter, kwargs...   forwarded to `ECCCo.EnergySampler` when the sampler is first
    #                         created; `n` is also the draw count for `choose_random`.
    #   from_buffer           forwarded to `rand` when `choose_random=true`.
    #   agg, p                forwarded to `distance`.
    #   choose_lowest_energy  use the `min(nmin, n, buffer size)` lowest-energy samples.
    #   choose_random         use `n` random draws from the sampler.
    #   (neither)             use the sampler's whole buffer.
    #   return_conditionals   return the selected samples instead of the loss.
    # Returns the mean over the selected samples of `distance(ce; from=sample, ...)`.
    @assert(
        !(choose_lowest_energy && choose_random),
        "Must choose either lowest energy or random samples or neither."
    )

    # Sampling is excluded from differentiation; the closure's value is returned directly
    # instead of being pushed into an untyped outer container.
    samples = ignore_derivatives() do
        params = ce.params
        if !haskey(params, :energy_sampler)
            params[:energy_sampler] = ECCCo.EnergySampler(
                ce; niter=niter, nsamples=n, kwargs...
            )
        end
        sampler = params[:energy_sampler]
        if choose_lowest_energy
            # Never ask for more samples than requested or than the buffer holds.
            k = min(nmin, n, size(sampler.buffer)[end])
            return ECCCo.get_lowest_energy_sample(sampler; n=k)
        elseif choose_random
            return rand(sampler, n; from_buffer=from_buffer)
        else
            return sampler.buffer
        end
    end

    # Skip the loss computation entirely when only the samples are requested.
    if return_conditionals
        return samples
    end

    losses = map(eachcol(samples)) do xsample
        distance(ce; from=xsample, agg=agg, p=p)
    end

    # Average over the samples actually used. Dividing by `n` would mis-scale the loss
    # whenever the number of selected samples differs from `n` (e.g. `nmin < n`).
    return reduce((x, y) -> x + y, losses) / length(losses)
end
"""
    distance_from_energy_l2(ce::AbstractCounterfactualExplanation; kwargs...)

Calls `distance_from_energy` with `p=2`, forwarding all other keyword arguments.
"""
function distance_from_energy_l2(ce::AbstractCounterfactualExplanation; kwargs...)
    return distance_from_energy(ce; p=2, kwargs...)
end
"""
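    distance_from_energy_cosine(ce::AbstractCounterfactualExplanation; kwargs...)

Computes the cosine distance from the counterfactual to conditional samples generated by
the energy sampler stored under `ce.params[:energy_sampler]` (created on first use).
Depending on `choose_lowest_energy` and `choose_random`, the lowest-energy samples, random
draws, or the sampler's whole buffer are used. With `return_conditionals=true` the
selected samples are returned instead of the loss.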
"""
function distance_from_energy_cosine(
    ce::AbstractCounterfactualExplanation;
    n::Int=50, niter=500, from_buffer=true, agg=mean,
    choose_lowest_energy=true,
    choose_random=false,
    nmin::Int=25,
    return_conditionals=false,
    kwargs...
)
    # Keywords:
    #   n, niter, kwargs...   forwarded to `ECCCo.EnergySampler` when the sampler is first
    #                         created; `n` is also the draw count for `choose_random`.
    #   from_buffer           forwarded to `rand` when `choose_random=true`.
    #   agg                   accepted but not used by this function.
    #   choose_lowest_energy  use the `min(nmin, n, buffer size)` lowest-energy samples.
    #   choose_random         use `n` random draws from the sampler.
    #   (neither)             use the sampler's whole buffer.
    #   return_conditionals   return the selected samples instead of the loss.
    # Returns the mean over the selected samples of `cos_dist(counterfactual, sample)`.
    @assert(
        !(choose_lowest_energy && choose_random),
        "Must choose either lowest energy or random samples or neither."
    )

    # Sampling is excluded from differentiation; the closure's value is returned directly
    # instead of being pushed into an untyped outer container.
    samples = ignore_derivatives() do
        params = ce.params
        if !haskey(params, :energy_sampler)
            params[:energy_sampler] = ECCCo.EnergySampler(
                ce; niter=niter, nsamples=n, kwargs...
            )
        end
        sampler = params[:energy_sampler]
        if choose_lowest_energy
            # Never ask for more samples than requested or than the buffer holds.
            k = min(nmin, n, size(sampler.buffer)[end])
            return ECCCo.get_lowest_energy_sample(sampler; n=k)
        elseif choose_random
            return rand(sampler, n; from_buffer=from_buffer)
        else
            return sampler.buffer
        end
    end

    # Skip the loss computation entirely when only the samples are requested.
    if return_conditionals
        return samples
    end

    # The counterfactual does not depend on the sample, so fetch it once.
    x′ = CounterfactualExplanations.counterfactual(ce)
    losses = map(eachcol(samples)) do xsample
        cos_dist(x′, xsample)
    end

    # Average over the samples actually used. Dividing by `n` would mis-scale the loss
    # whenever the number of selected samples differs from `n` (e.g. `nmin < n`).
    return reduce((x, y) -> x + y, losses) / length(losses)
end
"""
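    distance_from_energy_ssim(ce::AbstractCounterfactualExplanation; kwargs...)

Computes 1-SSIM from the counterfactual to conditional samples generated by the energy
sampler stored under `ce.params[:energy_sampler]` (created on first use), where SSIM is
the Structural Similarity Index. Depending on `choose_lowest_energy` and `choose_random`,
the lowest-energy samples, random draws, or the sampler's whole buffer are used. With
`return_conditionals=true` the selected samples are returned instead of the loss.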
"""
function distance_from_energy_ssim(
    ce::AbstractCounterfactualExplanation;
    n::Int=50, niter=500, from_buffer=true, agg=mean,
    choose_lowest_energy=true,
    choose_random=false,
    nmin::Int=25,
    return_conditionals=false,
    kwargs...
)
    # Keywords:
    #   n, niter, kwargs...   forwarded to `ECCCo.EnergySampler` when the sampler is first
    #                         created; `n` is also the draw count for `choose_random`.
    #   from_buffer           forwarded to `rand` when `choose_random=true`.
    #   agg                   accepted but not used by this function.
    #   choose_lowest_energy  use the `min(nmin, n, buffer size)` lowest-energy samples.
    #   choose_random         use `n` random draws from the sampler.
    #   (neither)             use the sampler's whole buffer.
    #   return_conditionals   return the selected samples instead of the loss.
    # Returns the mean over the selected samples of `ssim_dist(counterfactual, sample)`.
    @assert(
        !(choose_lowest_energy && choose_random),
        "Must choose either lowest energy or random samples or neither."
    )

    # Sampling is excluded from differentiation; the closure's value is returned directly
    # instead of being pushed into an untyped outer container.
    samples = ignore_derivatives() do
        params = ce.params
        if !haskey(params, :energy_sampler)
            params[:energy_sampler] = ECCCo.EnergySampler(
                ce; niter=niter, nsamples=n, kwargs...
            )
        end
        sampler = params[:energy_sampler]
        if choose_lowest_energy
            # Never ask for more samples than requested or than the buffer holds.
            k = min(nmin, n, size(sampler.buffer)[end])
            return ECCCo.get_lowest_energy_sample(sampler; n=k)
        elseif choose_random
            return rand(sampler, n; from_buffer=from_buffer)
        else
            return sampler.buffer
        end
    end

    # Skip the loss computation entirely when only the samples are requested.
    if return_conditionals
        return samples
    end

    # The counterfactual does not depend on the sample, so fetch it once.
    x′ = CounterfactualExplanations.counterfactual(ce)
    losses = map(eachcol(samples)) do xsample
        ssim_dist(x′, xsample)
    end

    # Average over the samples actually used. Dividing by `n` would mis-scale the loss
    # whenever the number of selected samples differs from `n` (e.g. `nmin < n`).
    return reduce((x, y) -> x + y, losses) / length(losses)
end
"""
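    distance_from_targets(ce::AbstractCounterfactualExplanation; agg=mean, n_nearest_neighbors=100, p=1)

Computes the average `p`-norm distance from the counterfactual to its
`n_nearest_neighbors` nearest samples of the target class in `ce.data.X`. If
`n_nearest_neighbors` is `nothing`, all target-class samples are used.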
"""
function distance_from_targets(
    ce::AbstractCounterfactualExplanation;
    agg=mean,
    n_nearest_neighbors::Union{Int,Nothing}=100,
    p::Int=1,
)
    # Keywords:
    #   agg                  combines the per-slice averages of the counterfactual.
    #   n_nearest_neighbors  number of smallest distances kept per slice; `nothing`
    #                        keeps all of them.
    #   p                    order of the norm used for the distance.

    # Columns of `ce.data.X` whose label equals the target.
    is_target = ce.data.output_encoder.labels .== ce.target
    target_samples = ce.data.X[:, is_target]
    x′ = CounterfactualExplanations.counterfactual(ce)

    loss = map(eachslice(x′, dims=ndims(x′))) do x
        Δ = map(xsample -> norm(x - xsample, p), eachcol(target_samples))
        if !isnothing(n_nearest_neighbors)
            # Clamp so that a target class with fewer samples than requested uses all
            # of them instead of raising a `BoundsError`.
            k = min(n_nearest_neighbors, length(Δ))
            Δ = sort(Δ)[1:k]
        end
        return mean(Δ)
    end

    return agg(loss)[1]
end
"""
    distance_from_targets_l2(ce::AbstractCounterfactualExplanation; kwargs...)

Calls `distance_from_targets` with `p=2`, forwarding all other keyword arguments.
"""
function distance_from_targets_l2(ce::AbstractCounterfactualExplanation; kwargs...)
    return distance_from_targets(ce; p=2, kwargs...)
end
"""
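    distance_from_targets_cosine(ce::AbstractCounterfactualExplanation; agg=mean, n_nearest_neighbors=100)

Computes the average cosine distance from the counterfactual to its
`n_nearest_neighbors` nearest samples of the target class in `ce.data.X`. If
`n_nearest_neighbors` is `nothing`, all target-class samples are used.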
"""
function distance_from_targets_cosine(
    ce::AbstractCounterfactualExplanation;
    agg=mean,
    n_nearest_neighbors::Union{Int,Nothing}=100,
)
    # Keywords:
    #   agg                  combines the per-slice averages of the counterfactual.
    #   n_nearest_neighbors  number of smallest distances kept per slice; `nothing`
    #                        keeps all of them.

    # Columns of `ce.data.X` whose label equals the target.
    is_target = ce.data.output_encoder.labels .== ce.target
    target_samples = ce.data.X[:, is_target]
    x′ = CounterfactualExplanations.counterfactual(ce)

    loss = map(eachslice(x′, dims=ndims(x′))) do x
        Δ = map(xsample -> cos_dist(x, xsample), eachcol(target_samples))
        if !isnothing(n_nearest_neighbors)
            # Clamp so that a target class with fewer samples than requested uses all
            # of them instead of raising a `BoundsError`.
            k = min(n_nearest_neighbors, length(Δ))
            Δ = sort(Δ)[1:k]
        end
        return mean(Δ)
    end

    return agg(loss)[1]
end
"""
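    distance_from_targets_ssim(ce::AbstractCounterfactualExplanation; agg=mean, n_nearest_neighbors=100)

Computes the average distance (1-SSIM) from the counterfactual to its
`n_nearest_neighbors` nearest samples of the target class in `ce.data.X`, where SSIM is
the Structural Similarity Index. If `n_nearest_neighbors` is `nothing`, all target-class
samples are used.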
"""
function distance_from_targets_ssim(
    ce::AbstractCounterfactualExplanation;
    agg=mean,
    n_nearest_neighbors::Union{Int,Nothing}=100,
)
    # Keywords:
    #   agg                  combines the per-slice averages of the counterfactual.
    #   n_nearest_neighbors  number of smallest distances kept per slice; `nothing`
    #                        keeps all of them.

    # Columns of `ce.data.X` whose label equals the target.
    is_target = ce.data.output_encoder.labels .== ce.target
    target_samples = ce.data.X[:, is_target]
    x′ = CounterfactualExplanations.counterfactual(ce)

    loss = map(eachslice(x′, dims=ndims(x′))) do x
        Δ = map(xsample -> ssim_dist(x, xsample), eachcol(target_samples))
        if !isnothing(n_nearest_neighbors)
            # Clamp so that a target class with fewer samples than requested uses all
            # of them instead of raising a `BoundsError`.
            k = min(n_nearest_neighbors, length(Δ))
            Δ = sort(Δ)[1:k]
        end
        return mean(Δ)
    end

    return agg(loss)[1]
end