Newer
Older
Pat Alt
committed
using CounterfactualExplanations: get_target_index
"""
Penalty for smooth conformal set size.
"""
# set_size_penalty: penalty term for the (smooth) conformal prediction set size of
# the counterfactual, computed via ConformalPrediction.ConformalTraining.smooth_size_loss.
# NOTE(review): this copy of the function appears to be missing lines and will not
# parse as-is:
#   - the parameter list opened on the next line is never closed; `ce`, `κ` and
#     `temp` are used in the body but no visible line declares them;
#   - `x` is read below, but only `X` is assigned — the line that binds `x`
#     (presumably an iteration over `X`) seems to have been dropped — TODO confirm;
#   - the `smooth_size_loss(` call shows only keyword arguments; its positional
#     arguments (presumably `conf_model`, `fitresult` and the input) are missing;
#   - the "Pat Alt" / "committed" lines look like pasted web-UI text, not Julia.
function set_size_penalty(
# Conformal model and its fitted state, taken from the wrapped model `ce.M`
# (neither is referenced again in the visible lines).
conf_model = ce.M.model
fitresult = ce.M.fitresult
# Decode the counterfactual's current state.
X = CounterfactualExplanations.decode_state(ce)
# Promote a vector input to a one-column matrix so downstream calls see a matrix.
x = ndims(x) == 1 ? x[:,:] : x
# Only apply the set-size penalty when the first entry of target_probs is at least
# 0.5; otherwise the penalty is zero.
if target_probs(ce, x)[1] >= 0.5
Pat Alt
committed
# Smooth set-size loss; `κ` and `temp` are forwarded as keywords and the first
# element of the result is used as the penalty.
l = ConformalPrediction.ConformalTraining.smooth_size_loss(
κ=κ,
temp=temp
)[1]
else
l = 0.0
end
return l
end
# energy_delta: energy-based loss that compares the energy of the proposed
# counterfactual with the energy of samples obtained through ECCCo.EnergySampler.
# Returns `gen_loss + reg_strength * reg_loss` unless `return_conditionals` is set.
# NOTE(review): this copy appears to interleave two revisions and will not parse:
#   - the keyword list below is never closed with `)`;
#   - `kwargs`, `reg_strength`, `E` and `_dict` are used but not defined in any
#     visible line, and `conditional_samples` is only initialised inside the
#     commented-out block;
#   - the `else` branch further down repeats keyword defaults (not valid as
#     statements) and the sampler-selection logic, and the trailing `eachcol` loop
#     looks like the body of a different function (an L1 distance to conditional
#     samples) — TODO reconcile with the upstream source.
function energy_delta(
ce::AbstractCounterfactualExplanation;
n::Int=50, niter=500, from_buffer=true, agg=mean,
choose_lowest_energy=true,
choose_random=false,
nmin::Int=25,
return_conditionals=false,
# Commented-out earlier sampling logic: cache an EnergySampler in `ce.params` and
# collect either the lowest-energy samples, random draws, or the whole buffer.
# nmin = minimum([nmin, n])
# @assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither."
# conditional_samples = []
# ignore_derivatives() do
# _dict = ce.params
# if !(:energy_sampler ∈ collect(keys(_dict)))
# _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
# end
# eng_sampler = _dict[:energy_sampler]
# if choose_lowest_energy
# nmin = minimum([nmin, size(eng_sampler.buffer)[end]])
# xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin)
# push!(conditional_samples, xmin)
# elseif choose_random
# push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer))
# else
# push!(conditional_samples, eng_sampler.buffer)
# end
# end
# Sampling is wrapped in ignore_derivatives so it is excluded from differentiation.
ignore_derivatives() do
# NOTE(review): this binds the EnergySampler object itself, not samples drawn from
# it; `xgenerated` below is then that object — confirm this is intended.
xsampled = ECCCo.EnergySampler(ce; niter=niter, nsamples=ce.num_counterfactuals, kwargs...)
push!(conditional_samples, xsampled)
xgenerated = conditional_samples[1] # conditional samples
xproposed = CounterfactualExplanations.decode_state(ce) # current state
Pat Alt
committed
# Index of `ce.target` within `ce.data.y_levels` (not referenced in the visible lines).
t = get_target_index(ce.data.y_levels, ce.target)
# Generative loss:
# mean of the element-wise energy difference between proposal and generated samples.
gen_loss = E(xproposed) .- E(xgenerated)
gen_loss = reduce((x, y) -> x + y, gen_loss) / length(gen_loss) # aggregate over samples
# Regularization loss:
# mean of the element-wise sum of squared energies of generated samples and proposal.
reg_loss = E(xgenerated).^2 .+ E(xproposed).^2
reg_loss = reduce((x, y) -> x + y, reg_loss) / length(reg_loss) # aggregate over samples
if !return_conditionals
# Total loss: generative term plus the regularisation term weighted by `reg_strength`.
return gen_loss + reg_strength * reg_loss
else
# NOTE(review): the next three lines are keyword defaults, not statements; they
# appear to have been pasted from another revision of the signature.
n::Int=50, niter=500, from_buffer=true, agg=mean,
choose_lowest_energy=true,
choose_random=false,
@assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither."
_dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
eng_sampler = _dict[:energy_sampler]
# Collect conditional samples: the `nmin` lowest-energy samples, `n` random draws,
# or the sampler's whole buffer.
if choose_lowest_energy
xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin)
push!(conditional_samples, xmin)
elseif choose_random
push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer))
else
push!(conditional_samples, eng_sampler.buffer)
end
# L1 distance (distance_l1) from the counterfactual to each sample column.
# NOTE(review): the `do` block opened here has no visible `end`, and the sum below
# is divided by `n` even when only `nmin` lowest-energy samples (or the full buffer)
# were collected, so it is not a true mean in those cases — confirm.
_loss = map(eachcol(conditional_samples[1])) do xsample
distance_l1(ce; from=xsample, agg=agg)
_loss = reduce((x,y) -> x + y, _loss) / n # aggregate over samples
if return_conditionals
return conditional_samples[1]
end
end
function distance_from_targets(
n::Int=1000, agg=mean,
n_nearest_neighbors::Union{Int,Nothing}=nothing,
target_idx = ce.data.output_encoder.labels .== ce.target
target_samples = ce.data.X[:,target_idx] |>
X -> X[:,rand(1:end,n)]
Δ = map(eachcol(target_samples)) do xsample
if n_nearest_neighbors != nothing
Δ = sort(Δ)[1:n_nearest_neighbors]
end