diff --git a/Manifest.toml b/Manifest.toml
index 9434115488ec05f1ed00e2d6098137a8214bf046..ad348f3b257ccb5a11e8de7adf4429a6d9588a3d 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -286,7 +286,7 @@ uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3"
 version = "0.3.2"

 [[deps.ConformalPrediction]]
-deps = ["CategoricalArrays", "ChainRules", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "NaturalSort", "Plots", "StatsBase"]
+deps = ["CategoricalArrays", "ChainRules", "Flux", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "NaturalSort", "Plots", "StatsBase"]
 path = "../ConformalPrediction.jl"
 uuid = "98bfc277-1877-43dc-819b-a3e38c30242f"
 version = "0.1.6"
diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml
index 4c077ed9a0fb768eb3181835b8c94abd02cd8374..899c4ea1bfe9e8e56c0aa885394b394a76667bf0 100644
--- a/notebooks/Manifest.toml
+++ b/notebooks/Manifest.toml
@@ -348,7 +348,7 @@ uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3"
 version = "0.3.2"

 [[deps.ConformalPrediction]]
-deps = ["CategoricalArrays", "ChainRules", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "NaturalSort", "Plots", "StatsBase"]
+deps = ["CategoricalArrays", "ChainRules", "Flux", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "NaturalSort", "Plots", "StatsBase"]
 path = "../../ConformalPrediction.jl"
 uuid = "98bfc277-1877-43dc-819b-a3e38c30242f"
 version = "0.1.6"
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index b4fb8b774215fbe9c035a9d9778bad5ae7ded2a4..b4efbc521ca2758c7437efb9468b2d3e76b4be43 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -33,7 +33,7 @@ First, let's create a couple of image classifier architectures:
 epochs = 100
 batch_size = minimum([Int(round(n_obs/10)), 128])
 n_hidden = 128
-activation = Flux.swish
+activation = Flux.relu

 builder = MLJFlux.@builder Flux.Chain(
     Dense(n_in, n_hidden, activation),
@@ -56,15 +56,20 @@ builder = MLJFlux.@builder Flux.Chain(
 #     ),
 #     σ=activation
 # )
-α = [1.0,1.0,1e-2]
+α = [1.0,1.0,5e-3]

 # Simple MLP:
 mlp = NeuralNetworkClassifier(
     builder=builder,
     epochs=epochs,
     batch_size=batch_size,
+    finaliser=x -> x,
+    loss=Flux.Losses.logitcrossentropy,
 )

+# Deep Ensemble:
+mlp_ens = EnsembleModel(model=mlp, n=5)
+
 # Joint Energy Model:
 𝒟x = Uniform(-1,1)
 𝒟y = Categorical(ones(output_dim) ./ output_dim)
@@ -88,39 +93,49 @@ jem = JointEnergyClassifier(
     epochs=epochs,
 )

-# Deep Ensemble:
-mlp_ens = EnsembleModel(model=mlp, n=50)
+# Deep Ensemble of Joint Energy Models:
+jem_ens = EnsembleModel(model=jem, n=5)
 ```

 ```{julia}
 cov = .95
-conf_model = conformal_model(jem; method=:simple_inductive, coverage=cov)
+conf_model = conformal_model(jem_ens; method=:adaptive_inductive, coverage=cov)
 mach = machine(conf_model, X, labels)
 fit!(mach)
 M = ECCCo.ConformalModel(mach.model, mach.fitresult)
 ```

 ```{julia}
-if mach.model.model isa JointEnergyModels.JointEnergyClassifier
-    jem = mach.model.model.jem
-    n_iter = 200
-    _w = 1500
-    plts = []
-    neach = 10
-    for i in 1:10
-        x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i)
-        plts_i = []
-        for j in 1:size(x, 2)
-            xj = x[:,j]
-            xj = reshape(xj, (n_digits, n_digits))
-            plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
-        end
-        plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
-        plts = [plts..., plt]
+if mach.model.model isa JointEnergyClassifier
+    sampler = mach.model.model.jem.sampler
+else
+    K = length(counterfactual_data.y_levels)
+    input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
+    𝒟x = Uniform(extrema(counterfactual_data.X)...)
+    𝒟y = Categorical(ones(K) ./ K)
+    sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size)
+end
+opt = ImproperSGLD()
+f(x) = logits(M, x)
+
+n_iter = 200
+_w = 1500
+plts = []
+neach = 10
+for i in 1:10
+    x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i)
+    plts_i = []
+    for j in 1:size(x, 2)
+        xj = x[:,j]
+        xj = reshape(xj, (n_digits, n_digits))
+        plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
     end
-    plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1))
-    display(plt)
+    plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
+    plts = [plts..., plt]
 end
+plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1))
+display(plt)
+
 ```

 ```{julia}
@@ -134,12 +149,12 @@ println("F1 score (test): $(round(f1,digits=3))")
 Random.seed!(1234)

 # Set up search:
-factual_label = 9
+factual_label = 4
 x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
-target = 4
+target = 5
 factual = predict_label(M, counterfactual_data, x)[1]
-γ = 0.5
-T = 250
+γ = 0.9
+T = 100

 η=1.0
@@ -149,6 +164,7 @@ ce_wachter = generate_counterfactual(
     x, target, counterfactual_data, M, generator;
     decision_threshold=γ, max_iter=T,
     initialization=:identity,
+    converge_when=:generator_conditions,
 )

 generator = GreedyGenerator(η=η)
@@ -156,6 +172,7 @@ ce_jsma = generate_counterfactual(
     x, target, counterfactual_data, M, generator;
     decision_threshold=γ, max_iter=T,
     initialization=:identity,
+    converge_when=:generator_conditions,
 )

 # ECCCo:
@@ -166,6 +183,7 @@ temp=0.1
 generator = ECCCoGenerator(
     λ=λ,
     temp=temp,
+    opt=Flux.Optimise.Adam(0.01),
 )
 ce_conformal = generate_counterfactual(
     x, target, counterfactual_data, M, generator;
diff --git a/paper/paper.pdf b/paper/paper.pdf
index eddb381562b5e51c254fb4629f64e5cc0707647a..56e5e6bc610e19203b81094a4d550becf73727cd 100644
Binary files a/paper/paper.pdf and b/paper/paper.pdf differ
diff --git a/paper/paper.tex b/paper/paper.tex
index 6ee1a43bf68b58fb0a7c4c743d1a728573ad951c..3f271797ba751480aa2f429098c24c45c55b9aa3 100644
--- a/paper/paper.tex
+++ b/paper/paper.tex
@@ -45,7 +45,7 @@
 \newtheorem{definition}{Definition}[section]

-\title{ECCCos from the Black Box: Letting Models speak for Themselves}
+\title{ECCoEs from the Black Box: Letting Models speak for Themselves}

 % The \author macro works with any number of authors. There are two commands
@@ -254,7 +254,7 @@ where $\hat{q}$ denotes the $(1-\alpha)$-quantile of $\mathcal{S}$ and $\alpha$
 Observe from Equation~\ref{eq:scp} that Conformal Prediction works on an instance-level basis, much like Counterfactual Explanations are local. The prediction set for an individual instance $\mathbf{x}_i$ depends only on the characteristics of that sample and the specified error rate. Intuitively, the set is more likely to include multiple labels for samples that are difficult to classify, so the set size is indicative of predictive uncertainty. To see why this effect is exacerbated by small choices for $\alpha$, consider the case of $\alpha=0$, which requires that the true label is covered by the prediction set with probability equal to one.

-\subsection{Conformal Counterfactual Explanations}
+\subsection{ECCoE: Energy-Constrained Conformal Counterfactual Explanation}

 The fact that conformal classifiers produce set-valued predictions introduces a challenge: it is not immediately obvious how to use such classifiers in the context of gradient-based counterfactual search. Put differently, it is not clear how to use prediction sets in Equation~\ref{eq:general}. Fortunately, \citet{stutz2022learning} have recently proposed a framework for Conformal Training that also hinges on differentiability. Specifically, they show how Stochastic Gradient Descent can be used to train classifiers not only for the discriminative task but also for additional objectives related to Conformal Prediction. One such objective is \textit{efficiency}: for a given target error rate $\alpha$, the efficiency of a conformal classifier improves as its average prediction set size decreases. To this end, the authors introduce a smooth set size penalty,
diff --git a/src/model.jl b/src/model.jl
index cc7cde3c24186017c91040107d06e1623f187089..38188816149952099159cc2b29334622c958b727 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -32,14 +32,46 @@ struct ConformalModel <: Models.AbstractDifferentiableJuliaModel
     end
 end

-# Outer constructor method:
+"""
+    _get_chains(fitresult)
+
+Private function that extracts the chains from a fitted model.
+"""
+function _get_chains(fitresult)
+    if fitresult isa MLJEnsembles.WrappedEnsemble
+        chains = map(res -> res[1], fitresult.ensemble)
+    else
+        chains = [fitresult[1]]
+    end
+    return chains
+end
+
+"""
+    _outdim(fitresult)
+
+Private function that extracts the output dimension from a fitted model.
+"""
+function _outdim(fitresult)
+    if fitresult isa MLJEnsembles.WrappedEnsemble
+        outdim = length(fitresult.ensemble[1][2])
+    else
+        outdim = length(fitresult[2])
+    end
+    return outdim
+end
+
+"""
+    ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symbol}=nothing)
+
+Outer constructor for `ConformalModel`. If `fitresult` is not supplied, the model is treated as unfitted. If `fitresult` is supplied, `likelihood` is inferred from the output dimension of the model and checked against any user-specified value. If `likelihood` is not specified, it defaults to `:classification_binary`.
+"""
 function ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symbol}=nothing)
     # Check if model is fitted and infer likelihood:
     if isnothing(fitresult)
         @info "Conformal Model is not fitted."
     else
-        outdim = length(fitresult[2])
+        outdim = _outdim(fitresult)
         _likelihood = outdim == 2 ? :classification_binary : :classification_multi
         @assert likelihood == _likelihood || isnothing(likelihood) "Specification of `likelihood` does not match the output dimension of the model."
         likelihood = _likelihood
@@ -52,7 +84,7 @@ function ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symb
     end

     # Construct model:
-    testmode!(fitresult[1])
+    testmode!.(_get_chains(fitresult))
     M = ConformalModel(model, fitresult, likelihood)
     return M
 end
@@ -82,7 +114,9 @@ which follows from the derivation here: https://stats.stackexchange.com/question
 function Models.logits(M::ConformalModel, X::AbstractArray)
     fitresult = M.fitresult
     function predict_logits(fitresult, x)
-        p̂ = fitresult[1](x)
+        p̂ = MLUtils.stack(map(chain -> chain(x),_get_chains(fitresult))) |>
+            p -> mean(p, dims=ndims(p)) |>
+            p -> MLUtils.unstack(p, dims=ndims(p))[1]
         if ndims(p̂) == 2
             p̂ = [p̂]
         end
diff --git a/src/sampling.jl b/src/sampling.jl
index 023c3d87fd7a138fc21431f4fd603aa19ace771e..3cd4b60470fc3d42e37300ddfe5bbcf903611c3b 100644
--- a/src/sampling.jl
+++ b/src/sampling.jl
@@ -92,9 +92,10 @@ Generates `n` samples from `EnergySampler` for conditioning value `y`.
 function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)

     # Generate samples:
-    chain = e.model.fitresult[1]
+    # chain = e.model.fitresult[1]
+    f(x) = logits(e.model, x)
     rule = e.opt
-    xsamples = e.sampler(chain, rule; niter=niter, n_samples=n, y=y)
+    xsamples = e.sampler(f, rule; niter=niter, n_samples=n, y=y)

     return xsamples
 end
diff --git a/www/eccco_mnist.png b/www/eccco_mnist.png
index edd2100307beb05f2b55511561bad5a4df0ad9f2..8b0a8a14d08b2310cfafe1ce68f55891d6ba1a33 100644
Binary files a/www/eccco_mnist.png and b/www/eccco_mnist.png differ
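Note (not part of the patch): the reworked `Models.logits` above averages logits across all ensemble members by stacking the per-chain outputs along a new trailing dimension and taking the mean. The following is a minimal sketch of that averaging step in plain Julia; `chain1` and `chain2` are toy stand-ins for fitted Flux chains, and Base `cat`/`dropdims` stand in for `MLUtils.stack`/`MLUtils.unstack`:

```julia
using Statistics: mean

# Toy stand-ins for two fitted chains, each mapping 3 features to 2 logits:
chain1(x) = [1.0 0.0 1.0; 0.0 1.0 0.0] * x
chain2(x) = [0.5 0.5 0.5; 0.5 0.5 0.5] * x
chains = [chain1, chain2]

x = rand(3, 4)  # batch of 4 inputs

# Stack per-chain logits along a new trailing dimension, average over it,
# then drop the singleton dimension (the stack |> mean |> unstack pipeline
# in the patched Models.logits does the same via MLUtils):
p = cat([c(x) for c in chains]...; dims=3)   # 2×4×2 array of logits
p̂ = dropdims(mean(p; dims=3); dims=3)        # 2×4 ensemble-mean logits
```

Averaging raw logits lines up with the notebook change that sets `finaliser=x -> x` and `loss=Flux.Losses.logitcrossentropy` on the MLP: each chain then outputs unnormalised logits rather than softmax probabilities, so the ensemble mean is taken in logit space.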