Commit 3f1c82aa authored by Pat Alt

work on granular results

parent 232dd4d4
@@ -2,7 +2,7 @@
julia_version = "1.8.5"
manifest_format = "2.0"
project_hash = "02c72169614894a0688a0f1373776f46a8df1f33"
project_hash = "fcc7af85fdd632057c2b8e76e581f46c71876dbf"
[[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"]
@@ -1285,6 +1285,12 @@ git-tree-sha1 = "37a311b0cd581764fc460f6632e6219dc32f9427"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
version = "0.21.8"
[[deps.MLJEnsembles]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"]
git-tree-sha1 = "bb8a1056b1d8b40f2f27167fc3ef6412a6719fbf"
uuid = "50ed68f4-41fd-4504-931a-ed422449fee0"
version = "0.3.2"
[[deps.MLJFlux]]
deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"]
git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0"
@@ -12,6 +12,7 @@ Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -22,6 +22,7 @@ setup_notebooks = quote
using MLDatasets
using MLDatasets: convert2image
using MLJBase
using MLJEnsembles
using MLJFlux
using MLUtils
using Plots
@@ -34,6 +35,6 @@ setup_notebooks = quote
Random.seed!(2023)
www_path = "www"
output_path = "artifacts"
img_height = 300;
img_height = 300
end;
\ No newline at end of file
```{julia}
include("notebooks/setup.jl")
eval(setup_notebooks);
```
# Synthetic data
```{julia}
#| output: false
# Data:
datasets = Dict(
:linearly_separable => load_linearly_separable(),
:overlapping => load_overlapping(),
:moons => load_moons(),
:circles => load_circles(),
:multi_class => load_multi_class(),
)
# Hyperparameters:
cvgs = [0.5, 0.75, 0.90, 0.95, 0.99]
temps = [0.05, 0.1, 0.5, 1.0, 2.0]
Λ = [0.1, 1.0, 5.0]
# Classifiers:
epochs = 100
link_fun = sigmoid
logreg = NeuralNetworkClassifier(builder=MLJFlux.Linear(σ=link_fun), epochs=epochs)
mlp = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=link_fun), epochs=epochs)
ensmbl = EnsembleModel(model=mlp, n=5)
classifiers = Dict(
:logreg => logreg,
:mlp => mlp,
# :ensmbl => ensmbl,
)
# Search parameters:
target = 2
factual = 1
```
```{julia}
results = DataFrame()
for (dataname, data) in datasets
# Data:
X = table(permutedims(data.X))
y = data.output_encoder.labels
for (clf_name, clf) in classifiers, cov in cvgs
# Classifier and coverage:
conf_model = conformal_model(clf; method=:simple_inductive, coverage=cov)
mach = machine(conf_model, X, y)
fit!(mach)
M = CCE.ConformalModel(mach.model, mach.fitresult)
# Set up CCE:
yhat = predict_label(M, data)
factual_label = data.y_levels[factual]
target_label = data.y_levels[target]
x = select_factual(data, rand(findall(yhat .== factual_label)))
for λ in Λ, temp in temps
# CCE for given classifier, coverage, temperature and λ:
generator = CCEGenerator(temp=temp, λ=λ)
@assert predict_label(M, data, x) != target_label
ce = try
generate_counterfactual(
x, target_label, data, M, generator;
initialization=:identity,
)
catch
missing
end
_results = DataFrame(
dataset = dataname,
classifier = clf_name,
coverage = cov,
temperature = temp,
λ = λ,
counterfactual = ce,
factual = factual_label,
target = target_label,
)
append!(results, _results)
end
end
end
```
```{julia}
```
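The loop above produces one row per combination of dataset, classifier, coverage, temperature and λ, with `missing` in the `counterfactual` column whenever the search errored. Below is a minimal sketch of how these granular results could be summarised, assuming only the columns constructed above; the failure-rate metric is an illustrative choice, not part of this commit.

```{julia}
using DataFrames
using Statistics

# Share of failed counterfactual searches per dataset/classifier/coverage cell:
summary_results = combine(
    groupby(results, [:dataset, :classifier, :coverage]),
    :counterfactual => (ce -> mean(ismissing.(ce))) => :failure_rate,
    nrow => :n_runs,
)
sort(summary_results, :failure_rate)
```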
@@ -8,9 +8,10 @@ include("penalties.jl")
include("losses.jl")
include("generator.jl")
include("sampling.jl")
# include("ConformalGenerator.jl")
using MLJFlux
MLJFlux.reformat(X, ::Type{<:AbstractMatrix}) = permutedims(X)
MLJFlux.reformat(X, ::Type{<:AbstractMatrix}) = X'
end
export CCEGenerator, EnergySampler, set_size_penalty, distance_from_energy
end
\ No newline at end of file
using CategoricalArrays
using CounterfactualExplanations
using CounterfactualExplanations.Generators
using CounterfactualExplanations.Models: binary_to_onehot
using CounterfactualExplanations.Objectives
using Flux
using LinearAlgebra
using Parameters
using SliceMap
using Statistics
mutable struct ConformalGenerator <: AbstractGradientBasedGenerator
loss::Union{Nothing,Function} # loss function
complexity::Function # complexity function
λ::Union{AbstractFloat,AbstractVector} # strength of penalty
decision_threshold::Union{Nothing,AbstractFloat}
opt::Flux.Optimise.AbstractOptimiser # optimizer
τ::AbstractFloat # tolerance for convergence
κ::Real # hyperparameter of the smooth set-size penalty (passed to `smooth_size_loss`)
temp::Real # temperature used by the smooth conformal losses
end
# API streamlining:
@with_kw struct ConformalGeneratorParams
opt::Flux.Optimise.AbstractOptimiser = Descent()
τ::AbstractFloat = 1e-3
κ::Real = 1.0
temp::Real = 0.5
end
"""
ConformalGenerator(;
loss::Union{Nothing,Function}=conformal_training_loss,
complexity::Function=norm,
λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0],
decision_threshold=nothing,
kwargs...
)
An outer constructor method that instantiates a `ConformalGenerator`.
# Examples
```julia-repl
generator = ConformalGenerator()
```
"""
function ConformalGenerator(;
loss::Union{Nothing,Function}=conformal_training_loss,
complexity::Function=norm,
λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0],
decision_threshold=nothing,
kwargs...
)
params = ConformalGeneratorParams(; kwargs...)
ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ, params.κ, params.temp)
end
@doc raw"""
conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...)
A configurable classification loss for conformal predictors: evaluates the smooth classification loss of the underlying conformal model at the current counterfactual state with respect to the (one-hot encoded) target and returns the batch average.
"""
function conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...)
conf_model = counterfactual_explanation.M.model
fitresult = counterfactual_explanation.M.fitresult
generator = counterfactual_explanation.generator
temp = hasfield(typeof(generator), :temp) ? generator.temp : nothing
K = length(counterfactual_explanation.data.y_levels)
X = CounterfactualExplanations.decode_state(counterfactual_explanation)
y = counterfactual_explanation.target_encoded[:,:,1]
if counterfactual_explanation.M.likelihood == :classification_binary
y = binary_to_onehot(y)
end
y = permutedims(y)
loss = SliceMap.slicemap(X, dims=(1, 2)) do x
x = Matrix(x)
ConformalPrediction.classification_loss(
conf_model, fitresult, x, y;
temp=temp,
loss_matrix=Float32.(ones(K,K))
)
end
loss = mean(loss)
return loss
end
"""
set_size_penalty(
generator::ConformalGenerator,
counterfactual_explanation::AbstractCounterfactualExplanation,
)
An additional penalty for the `ConformalGenerator` that measures the smooth size of the conformal prediction set at the current counterfactual state (via `ConformalPrediction.smooth_size_loss`), averaged over the batch.
"""
function set_size_penalty(
generator::ConformalGenerator,
counterfactual_explanation::AbstractCounterfactualExplanation,
)
conf_model = counterfactual_explanation.M.model
fitresult = counterfactual_explanation.M.fitresult
X = CounterfactualExplanations.decode_state(counterfactual_explanation)
loss = SliceMap.slicemap(X, dims=(1, 2)) do x
x = Matrix(x)
ConformalPrediction.smooth_size_loss(
conf_model, fitresult, x;
κ=generator.κ,
temp=generator.temp
)
end
loss = mean(loss)
return loss
end
# Complexity:
"""
h(generator::AbstractGenerator, counterfactual_explanation::AbstractCounterfactualExplanation)
Applies the complexity penalty of the `ConformalGenerator` to the current counterfactual state: the distance from the factual plus, once the counterfactual lies in the target domain, the set-size penalty. A scalar `λ` applies one penalty strength to both terms; a vector `λ` weights the distance and set-size terms separately.
"""
function Generators.h(
generator::ConformalGenerator,
counterfactual_explanation::AbstractCounterfactualExplanation,
)
# Distance from factual:
dist_ = generator.complexity(
counterfactual_explanation.x .-
CounterfactualExplanations.decode_state(counterfactual_explanation),
)
# Set size penalty:
in_target_domain = all(target_probs(counterfactual_explanation) .>= 0.5)
if in_target_domain
Ω = set_size_penalty(generator, counterfactual_explanation)
else
Ω = 0
end
if length(generator.λ) == 1
penalty = generator.λ * (dist_ .+ Ω)
else
penalty = generator.λ[1] * dist_ .+ generator.λ[2] * Ω
end
return penalty
end
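A small numeric illustration of the weighting logic above (the numbers are arbitrary, not from the source):

```julia
# Scalar λ: one penalty strength applied to distance and set-size term jointly.
λ = 0.5; dist_ = 2.0; Ω = 0.3
penalty = λ * (dist_ + Ω)           # 1.15

# Vector λ: distance and set-size term weighted separately.
λ = [0.1, 1.0]
penalty = λ[1] * dist_ + λ[2] * Ω   # 0.5
```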
using CounterfactualExplanations.Objectives
"Constructor for `CCEGenerator`."
function CCEGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], κ::Real=0.0, temp::Real=0.05, kwargs...)
function _set_size_penalty(ce::AbstractCounterfactualExplanation)
return CCE.set_size_penalty(ce; κ=κ, temp=temp)
end
_penalties = [Objectives.distance_l2, _set_size_penalty]
λ = λ isa AbstractFloat ? [0.0, λ] : λ
return Generator(; penalty=_penalties, λ=λ, kwargs...)
end
\ No newline at end of file
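For context, a hedged usage sketch of the composed generator (the keyword values below are just the defaults from the constructor above). Note that passing a scalar `λ` sets the weight of the distance penalty to `0.0` and applies `λ` to the set-size penalty only:

```julia
# Defaults: distance_l2 weighted by 0.1, set-size penalty weighted by 1.0.
generator = CCEGenerator()

# Scalar λ: only the set-size penalty is weighted (distance weight becomes 0.0).
generator = CCEGenerator(λ=0.5, temp=0.1)
```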
@@ -30,13 +30,24 @@
# Outer constructor method:
function ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symbol}=nothing)
# Check if model is fitted and infer likelihood:
if isnothing(fitresult)
@info "Conformal Model is not fitted."
else
outdim = length(fitresult[2])
_likelihood = outdim == 2 ? :classification_binary : :classification_multi
@assert likelihood == _likelihood || isnothing(likelihood) "Specification of `likelihood` does not match the output dimension of the model."
likelihood = _likelihood
end
# Default to binary classification, if not specified or inferred:
if isnothing(likelihood)
likelihood = :classification_binary
@info "Likelihood not specified. Defaulting to $likelihood."
end
# Construct model:
M = ConformalModel(model, fitresult, likelihood)
return M
end
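A hedged sketch of how the likelihood inference above plays out, mirroring the notebook call `CCE.ConformalModel(mach.model, mach.fitresult)`; the classifier `mlp` and the data `X`, `y` are assumed to come from the synthetic-data setup earlier:

```julia
# Fit a simple inductive conformal classifier, as in the notebook above.
conf_model = conformal_model(mlp; method=:simple_inductive, coverage=0.95)
mach = machine(conf_model, X, y)
fit!(mach)

# The constructor infers the likelihood from `length(fitresult[2])`:
# an output dimension of 2 gives :classification_binary, otherwise :classification_multi.
M = ConformalModel(mach.model, mach.fitresult)
```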