# sampling.jl 4.00 KiB
using CounterfactualExplanations
using Distributions
using Flux
using JointEnergyModels
"""
    (model::AbstractFittedModel)(x)

Make a fitted model callable on data `x`. The call evaluates
`CounterfactualExplanations.predict_proba(model, nothing, x)` and returns the element-wise
natural logarithm of the result, i.e. the log of the predicted probabilities.

NOTE(review): an earlier version of this docstring described the output as "softmax
logits" and stated that outputs are one-hot encoded in the binary case. Whether that holds
depends on `predict_proba`, which is defined outside this file -- confirm before relying
on it.
"""
function (model::AbstractFittedModel)(x)
    probs = CounterfactualExplanations.predict_proba(model, nothing, x)
    return log.(probs)
end
"""
    EnergySampler

Base type that stores information relevant to energy-based posterior sampling from
`AbstractFittedModel`.

# Fields
- `model`: fitted model whose logits are handed to the sampler when samples are generated.
- `data`: the `CounterfactualData` the sampler was set up from.
- `sampler`: the `JointEnergyModels.ConditionalSampler` that is called to draw samples.
- `opt`: sampling rule passed to `sampler` on every draw.
- `buffer`: samples generated so far; `nothing` until `generate_samples!` has run once.
  Later batches are concatenated along the last dimension.
- `yidx`: conditioning value most recently passed to `generate_samples!`.
"""
mutable struct EnergySampler
    model::AbstractFittedModel
    data::CounterfactualData
    sampler::JointEnergyModels.ConditionalSampler
    opt::JointEnergyModels.AbstractSamplingRule
    # `nothing` marks an empty buffer that has not been filled yet.
    buffer::Union{Nothing,AbstractArray}
    # NOTE(review): `Union{Nothing,Any}` is the same type as `Any`, so this field is
    # effectively untyped.
    yidx::Union{Nothing,Any}
end
"""
    EnergySampler(
        model::AbstractFittedModel,
        data::CounterfactualData,
        y::Any;
        opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(),
        niter::Int=100,
        nsamples::Int=100
    )

Constructor for `EnergySampler` that takes a `model`, `data` and conditioning value `y` as
inputs.

# Arguments
- `model`: fitted model to sample from.
- `data`: counterfactual data; `data.X` determines the input size and the support of the
  uniform prior over inputs, `data.y_levels` determines the number of classes.
- `y`: conditioning value. Either one of `data.y_levels` or a class index in
  `1:length(data.y_levels)`. If `y` is a valid level it is always interpreted as a level.

# Keywords
- `opt`: sampling rule handed to the conditional sampler.
- `niter`: number of sampling iterations used to fill the initial buffer.
- `nsamples`: number of samples generated into the initial buffer.

# Returns
An `EnergySampler` whose buffer already holds `nsamples` samples conditioned on `y`.

# Throws
- `AssertionError` if `y` is neither a level of `data` nor a valid class index.
"""
function EnergySampler(
    model::AbstractFittedModel,
    data::CounterfactualData,
    y::Any;
    opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(),
    niter::Int=100,
    nsamples::Int=100
)
    K = length(data.y_levels)
    is_level = y ∈ data.y_levels
    @assert is_level || y ∈ 1:K "`y` must be one of `data.y_levels` or a class index in `1:$K`."

    # Size of a single observation: one slice along the last dimension of `data.X`.
    input_size = size(selectdim(data.X, ndims(data.X), 1))

    # Uniform prior over the observed input range and a uniform prior over classes.
    𝒟x = Uniform(extrema(data.X)...)
    𝒟y = Categorical(ones(K) ./ K)
    sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size)

    # Levels are looked up; anything else has been validated above as a class index and is
    # used directly. Previously an index that is not itself a level was still passed to
    # the label lookup even though the validation accepts it.
    yidx = is_level ? get_target_index(data.y_levels, y) : Int(y)

    # Start with an empty buffer, then fill it with samples conditioned on `yidx`.
    energy_sampler = EnergySampler(model, data, sampler, opt, nothing, yidx)
    generate_samples!(energy_sampler, nsamples, yidx; niter=niter)

    return energy_sampler
end
"""
    EnergySampler(ce::CounterfactualExplanation; kwargs...)

Build an `EnergySampler` from a `CounterfactualExplanation`. The explanation's model
(`ce.M`), data (`ce.data`) and target (`ce.target`) are forwarded to the
`EnergySampler(model, data, y; kwargs...)` constructor, with `ce.target` acting as the
conditioning value of `y`. All keyword arguments are passed through unchanged.
"""
function EnergySampler(ce::CounterfactualExplanation; kwargs...)
    return EnergySampler(ce.M, ce.data, ce.target; kwargs...)
end
"""
    generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)

Draw `n` samples from `e` for conditioning value `y` and return them. The stored sampler
is called with a closure that evaluates `logits(e.model, x)`, the stored sampling rule
`e.opt`, and `niter` iterations. The sampler's buffer in `e` is not modified here.
"""
function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
    # The sampler only needs a function from inputs to logits.
    logit_fn = x -> logits(e.model, x)
    return e.sampler(logit_fn, e.opt; niter=niter, n_samples=n, y=y)
end
"""
    generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100)

Draw `n` samples from `e` for conditioning value `y` and store them in `e.buffer`. An
empty (`nothing`) buffer is replaced by the new samples; otherwise the new samples are
appended along the buffer's last dimension. `e.yidx` is set to `y`, and `y` is returned.
"""
function generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100)
    fresh = generate_samples(e, n, y; niter=niter)
    e.buffer = isnothing(e.buffer) ? fresh : cat(e.buffer, fresh; dims=ndims(e.buffer))
    e.yidx = y
    return y
end
"""
    Base.rand(sampler::EnergySampler, n::Int=100; from_buffer=true, niter::Int=100)

Overloads the `rand` method to randomly draw `n` samples from `EnergySampler`.

# Keywords
- `from_buffer`: if `true`, draw `n` samples (with replacement) from the samples already
  stored in `sampler.buffer`. If `false`, generate `n` new samples conditioned on
  `sampler.yidx`; the buffer is neither read nor modified in that case.
- `niter`: number of sampling iterations; only used when `from_buffer` is `false`.

# Returns
An array holding `n` samples along its last dimension. When drawing from the buffer the
result is a copy, not a view.

# Throws
- `ArgumentError` if `from_buffer` is `true` but the buffer is still empty.
"""
function Base.rand(sampler::EnergySampler, n::Int=100; from_buffer=true, niter::Int=100)
    # Fresh samples do not depend on the buffer, so do not touch it (or the RNG) here.
    if !from_buffer
        return generate_samples(sampler, n, sampler.yidx; niter=niter)
    end

    buffer = sampler.buffer
    if isnothing(buffer)
        throw(ArgumentError(
            "Buffer is empty: call `generate_samples!` first or use `from_buffer=false`.",
        ))
    end

    # Samples live along the last dimension, matching `generate_samples!` and
    # `get_lowest_energy_sample`. For a two-dimensional buffer this is `buffer[:, idx]`.
    d = ndims(buffer)
    idx = rand(1:size(buffer, d), n)
    return copy(selectdim(buffer, d, idx))
end
"""
    get_lowest_energy_sample(sampler::EnergySampler; n::Int=5)

Chooses the samples with the lowest energy (i.e. highest probability) from
`EnergySampler`.

`energy` is evaluated on the whole buffer for the stored conditioning value
`sampler.yidx`, using an aggregation that returns the positions of the `n` smallest values
it is given (presumably one energy per buffered sample -- confirm against `energy`). Those
positions are then selected along the last dimension of the buffer with `selectdim`, so
the result refers to the buffer's data rather than copying it.
"""
function get_lowest_energy_sample(sampler::EnergySampler; n::Int=5)
    buffer = sampler.buffer
    # Positions of the `n` smallest entries, in ascending order of value.
    pick_lowest = energies -> partialsortperm(energies, 1:n)
    keep = energy(sampler.sampler, sampler.model, buffer, sampler.yidx; agg=pick_lowest)
    return selectdim(buffer, ndims(buffer), keep)
end