diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml index 059c1d6ad34ffb39062c980d2178ca09635bb052..ab30f25db3b1933c4fda157d8e7463c9361310fe 100644 --- a/notebooks/Manifest.toml +++ b/notebooks/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "02c72169614894a0688a0f1373776f46a8df1f33" +project_hash = "fcc7af85fdd632057c2b8e76e581f46c71876dbf" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] @@ -1285,6 +1285,12 @@ git-tree-sha1 = "37a311b0cd581764fc460f6632e6219dc32f9427" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" version = "0.21.8" +[[deps.MLJEnsembles]] +deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"] +git-tree-sha1 = "bb8a1056b1d8b40f2f27167fc3ef6412a6719fbf" +uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" +version = "0.3.2" + [[deps.MLJFlux]] deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0" diff --git a/notebooks/Project.toml b/notebooks/Project.toml index 771406311a969da68e702d74fa6147f7eef78295..7a43f27719ca60b1c9a628f1940c826fccb97aaf 100644 --- a/notebooks/Project.toml +++ b/notebooks/Project.toml @@ -12,6 +12,7 @@ Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" diff --git a/notebooks/setup.jl b/notebooks/setup.jl index e2bc7aa0d5db0d05ac3ec01d149f4dde273d8695..d26fc37a09b3570437edbbbe3c081359c6d0c4f1 100644 --- a/notebooks/setup.jl +++ b/notebooks/setup.jl @@ -22,6 +22,7 @@ setup_notebooks = quote using MLDatasets using MLDatasets: convert2image using MLJBase + using MLJEnsembles using MLJFlux using MLUtils using Plots @@ -34,6 +35,6 @@ setup_notebooks = quote Random.seed!(2023) www_path = "www" output_path = "artifacts" - img_height = 300; + img_height = 300 end; \ No newline at end of file diff --git a/notebooks/synthetic.qmd b/notebooks/synthetic.qmd index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..7cf70d5027144ff2eaae6f3a70ae34273150470f 100644 --- a/notebooks/synthetic.qmd +++ b/notebooks/synthetic.qmd @@ -0,0 +1,99 @@ +```{julia} +include("notebooks/setup.jl") +eval(setup_notebooks); +``` + +# Synthetic data + +```{julia} +#| output: false + +# Data: +datasets = Dict( + :linearly_separable => load_linearly_separable(), + :overlapping => load_overlapping(), + :moons => load_moons(), + :circles => load_circles(), + :multi_class => load_multi_class(), +) + +# Hyperparameters: +cvgs = [0.5, 0.75, 0.90, 0.95, 0.99] +temps = [0.05, 0.1, 0.5, 1.0, 2.0] +Λ = [0.1, 1.0, 5.0] + +# Classifiers: +epochs = 100 +link_fun = sigmoid +logreg = NeuralNetworkClassifier(builder=MLJFlux.Linear(σ=link_fun), epochs=epochs) +mlp = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=link_fun), epochs=epochs) +ensmbl = EnsembleModel(model=mlp, n=5) +classifiers = Dict( + :logreg => logreg, + :mlp => mlp, + # :ensmbl => ensmbl, +) + +# Search parameters: +target = 2 +factual = 1 +``` + +```{julia} +results = DataFrame() +for (dataname, data) in datasets + + # Data: + X = table(permutedims(data.X)) + y = data.output_encoder.labels + + for (clf_name, clf) in classifiers, cov in cvgs + + # Classifier and coverage: + conf_model = conformal_model(clf; method=:simple_inductive, coverage=cov) + mach = machine(conf_model, X, y) + fit!(mach) + M = CCE.ConformalModel(mach.model, mach.fitresult) + + # Set up CCE: + yhat = predict_label(M, data) + factual_label = data.y_levels[factual] + target_label = data.y_levels[target] + x = select_factual(data,rand(findall(yhat .== factual_label))) + + for λ in Λ, temp in temps + + # CCE for given classifier, coverage, temperature and λ: + generator = CCEGenerator(temp=temp, λ=λ) + @assert predict_label(M, data, x) != target_label + ce = try + generate_counterfactual( + x, target_label, data, M, generator; + initialization=:identity, + ) + catch + missing + end + + _results = DataFrame( + dataset = dataname, + classifier = clf_name, + coverage = cov, + temperature = temp, + λ = λ, + counterfactual = ce, + factual = factual_label, + target = target_label, + ) + append!(results, _results) + + end + + end + +end +``` + +```{julia} + +``` diff --git a/src/CCE.jl b/src/CCE.jl index da9ffd9ec03684f171814e9fb58c107771caf2e7..1b678a8121a413e2282e4a0096bdd67a02b0d605 100644 --- a/src/CCE.jl +++ b/src/CCE.jl @@ -8,9 +8,10 @@ include("penalties.jl") include("losses.jl") include("generator.jl") include("sampling.jl") -# include("ConformalGenerator.jl") using MLJFlux -MLJFlux.reformat(X, ::Type{<:AbstractMatrix}) = permutedims(X) +MLJFlux.reformat(X, ::Type{<:AbstractMatrix}) = X' -end +export CCEGenerator, EnergySampler, set_size_penalty, distance_from_energy + +end \ No newline at end of file diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl deleted file mode 100644 index b169399b47a65d4fe0fff8c35087bf68c990017a..0000000000000000000000000000000000000000 --- a/src/ConformalGenerator.jl +++ /dev/null @@ -1,149 +0,0 @@ -using CategoricalArrays -using CounterfactualExplanations -using CounterfactualExplanations.Generators -using CounterfactualExplanations.Models: binary_to_onehot -using CounterfactualExplanations.Objectives -using Flux -using LinearAlgebra -using Parameters -using SliceMap -using Statistics - -mutable struct ConformalGenerator <: AbstractGradientBasedGenerator - loss::Union{Nothing,Function} # loss function - complexity::Function # complexity function - λ::Union{AbstractFloat,AbstractVector} # strength of penalty - decision_threshold::Union{Nothing,AbstractFloat} - opt::Flux.Optimise.AbstractOptimiser # optimizer - τ::AbstractFloat # tolerance for convergence - κ::Real - temp::Real -end - -# API streamlining: -@with_kw struct ConformalGeneratorParams - opt::Flux.Optimise.AbstractOptimiser = Descent() - τ::AbstractFloat = 1e-3 - κ::Real = 1.0 - temp::Real = 0.5 -end - -""" - ConformalGenerator(; - loss::Union{Nothing,Function}=conformal_training_loss, - complexity::Function=norm, - λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0], - decision_threshold=nothing, - kwargs... - ) - -An outer constructor method that instantiates a generic generator. - -# Examples -```julia-repl -generator = ConformalGenerator() -``` -""" -function ConformalGenerator(; - loss::Union{Nothing,Function}=conformal_training_loss, - complexity::Function=norm, - λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0], - decision_threshold=nothing, - kwargs... -) - params = ConformalGeneratorParams(; kwargs...) - ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ, params.κ, params.temp) -end - -@doc raw""" - conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...) - -A configurable classification loss function for Conformal Predictors. -""" -function conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...) - conf_model = counterfactual_explanation.M.model - fitresult = counterfactual_explanation.M.fitresult - generator = counterfactual_explanation.generator - temp = hasfield(typeof(generator), :temp) ? generator.temp : nothing - K = length(counterfactual_explanation.data.y_levels) - X = CounterfactualExplanations.decode_state(counterfactual_explanation) - y = counterfactual_explanation.target_encoded[:,:,1] - if counterfactual_explanation.M.likelihood == :classification_binary - y = binary_to_onehot(y) - end - y = permutedims(y) - loss = SliceMap.slicemap(X, dims=(1, 2)) do x - x = Matrix(x) - ConformalPrediction.classification_loss( - conf_model, fitresult, x, y; - temp=temp, - loss_matrix=Float32.(ones(K,K)) - ) - end - loss = mean(loss) - return loss -end - -""" - set_size_penalty( - generator::ConformalGenerator, - counterfactual_explanation::AbstractCounterfactualExplanation, - ) - -Additional penalty for ConformalGenerator. -""" -function set_size_penalty( - generator::ConformalGenerator, - counterfactual_explanation::AbstractCounterfactualExplanation, -) - - conf_model = counterfactual_explanation.M.model - fitresult = counterfactual_explanation.M.fitresult - X = CounterfactualExplanations.decode_state(counterfactual_explanation) - loss = SliceMap.slicemap(X, dims=(1, 2)) do x - x = Matrix(x) - ConformalPrediction.smooth_size_loss( - conf_model, fitresult, x; - κ=generator.κ, - temp=generator.temp - ) - end - loss = mean(loss) - - return loss - -end - -# Complexity: -""" - h(generator::AbstractGenerator, counterfactual_explanation::AbstractCounterfactualExplanation) - -The default method to apply the generator complexity penalty to the current counterfactual state for any generator. -""" -function Generators.h( - generator::ConformalGenerator, - counterfactual_explanation::AbstractCounterfactualExplanation, -) - - # Distance from factual: - dist_ = generator.complexity( - counterfactual_explanation.x .- - CounterfactualExplanations.decode_state(counterfactual_explanation), - ) - - # Set size penalty: - in_target_domain = all(target_probs(counterfactual_explanation) .>= 0.5) - if in_target_domain - Ω = set_size_penalty(generator, counterfactual_explanation) - else - Ω = 0 - end - - if length(generator.λ) == 1 - penalty = generator.λ * (dist_ .+ Ω) - else - penalty = generator.λ[1] * dist_ .+ generator.λ[2] * Ω - end - - return penalty -end diff --git a/src/generator.jl b/src/generator.jl index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..5dc55065db336e335d0aa7be2c6606843925a22e 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -0,0 +1,11 @@ +using CounterfactualExplanations.Objectives + +"Constructor for `CCEGenerator`." +function CCEGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], κ::Real=0.0, temp::Real=0.05, kwargs...) + function _set_size_penalty(ce::AbstractCounterfactualExplanation) + return CCE.set_size_penalty(ce; κ=κ, temp=temp) + end + _penalties = [Objectives.distance_l2, _set_size_penalty] + λ = λ isa AbstractFloat ? [0.0, λ] : λ + return Generator(; penalty=_penalties, λ=λ, kwargs...) +end \ No newline at end of file diff --git a/src/model.jl b/src/model.jl index dcf17b4d58afc237c305cacd20a69028e3c29c68..28dc1f412d3f923334a5a15dc06d63c60ac8440a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -30,13 +30,24 @@ end # Outer constructor method: function ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symbol}=nothing) + + # Check if model is fitted and infer likelihood: if isnothing(fitresult) @info "Conformal Model is not fitted." + else + outdim = length(fitresult[2]) + _likelihood = outdim == 2 ? :classification_binary : :classification_multi + @assert likelihood == _likelihood || isnothing(likelihood) "Specification of `likelihood` does not match the output dimension of the model." + likelihood = _likelihood end + + # Default to binary classification, if not specified or inferred: if isnothing(likelihood) likelihood = :classification_binary @info "Likelihood not specified. Defaulting to $likelihood." end + + # Construct model: M = ConformalModel(model, fitresult, likelihood) return M end