Skip to content
Snippets Groups Projects
model.jl 6.28 KiB
Newer Older
Pat Alt's avatar
Pat Alt committed
using ChainRules: ignore_derivatives
pat-alt's avatar
pat-alt committed
using ConformalPrediction
using CounterfactualExplanations.Models
using Flux
pat-alt's avatar
pat-alt committed
using MLJBase
Pat Alt's avatar
Pat Alt committed
using MLJEnsembles
using MLJFlux
pat-alt's avatar
pat-alt committed
using MLUtils
using Statistics

Pat Alt's avatar
Pat Alt committed
# Union of the atomic model types named below: a probabilistic `MLJFlux` model, or a
# probabilistic `MLJEnsembles` ensemble whose atom is such a model.
const CompatibleAtomicModel = Union{
    <:MLJFlux.MLJFluxProbabilistic,
    MLJEnsembles.ProbabilisticEnsembleModel{<:MLJFlux.MLJFluxProbabilistic},
}

pat-alt's avatar
pat-alt committed
"""
    ConformalModel <: Models.AbstractDifferentiableModel
pat-alt's avatar
pat-alt committed

Constructor for models trained in `Flux.jl`. 
"""
struct ConformalModel <: Models.AbstractDifferentiableModel
pat-alt's avatar
pat-alt committed
    model::ConformalPrediction.ConformalProbabilisticSet
pat-alt's avatar
pat-alt committed
    fitresult::Any
pat-alt's avatar
uh  
pat-alt committed
    likelihood::Union{Nothing,Symbol}
pat-alt's avatar
pat-alt committed
    function ConformalModel(model, fitresult, likelihood)
pat-alt's avatar
uh  
pat-alt committed
        if likelihood  [:classification_binary, :classification_multi] || isnothing(likelihood)
pat-alt's avatar
pat-alt committed
            new(model, fitresult, likelihood)
pat-alt's avatar
pat-alt committed
        else
            throw(
                ArgumentError(
pat-alt's avatar
uh  
pat-alt committed
                    "`likelihood` should either be `nothing` or in `[:classification_binary,:classification_multi]`",
pat-alt's avatar
pat-alt committed
                ),
            )
        end
    end
end

pat-alt's avatar
pat-alt committed
"""
    _get_chains(fitresult)

Private function that extracts the chains from a fitted model.
"""
function _get_chains(fitresult)
Pat Alt's avatar
Pat Alt committed
    
    chains = []

    ignore_derivatives() do 
        if fitresult isa MLJEnsembles.WrappedEnsemble
            _chains = map(res -> res[1], fitresult.ensemble)
        else
            _chains = [fitresult[1]]
        end
        push!(chains, _chains...)
pat-alt's avatar
pat-alt committed
    end
Pat Alt's avatar
Pat Alt committed
    
pat-alt's avatar
pat-alt committed
    return chains
end

"""
    _outdim(fitresult)

Private function that extracts the output dimension from a fitted model.
"""
function _outdim(fitresult)
    if fitresult isa MLJEnsembles.WrappedEnsemble
        outdim = length(fitresult.ensemble[1][2])
    else
        outdim = length(fitresult[2])
    end
    return outdim
end

Pat Alt's avatar
Pat Alt committed
"""
pat-alt's avatar
pat-alt committed
    _get_sampler(model::AbstractFittedModel)
Pat Alt's avatar
Pat Alt committed

Private helper function that extracts the sampler from a fitted model.
"""
pat-alt's avatar
pat-alt committed
function _get_sampler(model::AbstractFittedModel)
Pat Alt's avatar
Pat Alt committed
    _mod = model.model
pat-alt's avatar
pat-alt committed
    if hasfield(typeof(_mod), :model)
        if _mod.model isa MLJEnsembles.EitherEnsembleModel
            _mod = _mod.model
        end
        if _mod.model isa JointEnergyClassifier
            sampler = _mod.model.sampler
        else
            sampler = false
        end
Pat Alt's avatar
Pat Alt committed
    else
        sampler = false
    end
    return sampler
end

"""
pat-alt's avatar
pat-alt committed
    _has_sampler(model::AbstractFittedModel)
Pat Alt's avatar
Pat Alt committed

Private helper function that checks if a fitted model has a sampler.
"""
pat-alt's avatar
pat-alt committed
function _has_sampler(model::AbstractFittedModel)
Pat Alt's avatar
Pat Alt committed
    return !(_get_sampler(model) isa Bool)
end

pat-alt's avatar
pat-alt committed
"""
    ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symbol}=nothing)

Outer constructor for `ConformalModel`. If `fitresult` is not specified, the model is not fitted and `likelihood` is inferred from the model. If `fitresult` is specified, `likelihood` is inferred from the output dimension of the model. If `likelihood` is not specified, it defaults to `:classification_binary`.
"""
pat-alt's avatar
uh  
pat-alt committed
function ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symbol}=nothing)
Pat Alt's avatar
Pat Alt committed

    # Check if model is fitted and infer likelihood:
pat-alt's avatar
uh  
pat-alt committed
    if isnothing(fitresult)
        @info "Conformal Model is not fitted."
    end
Pat Alt's avatar
Pat Alt committed

    # Default to binary classification, if not specified or inferred:
pat-alt's avatar
uh  
pat-alt committed
    if isnothing(likelihood)
        likelihood = :classification_multi
pat-alt's avatar
uh  
pat-alt committed
        @info "Likelihood not specified. Defaulting to $likelihood."
    end
Pat Alt's avatar
Pat Alt committed

    # Construct model:
pat-alt's avatar
pat-alt committed
    testmode!.(_get_chains(fitresult))
pat-alt's avatar
uh  
pat-alt committed
    M = ConformalModel(model, fitresult, likelihood)
    return M
pat-alt's avatar
pat-alt committed
end

"""
    get_logits(f::Flux.Chain, x)

Helper function to return logits in case final layer is an activation function.
"""
get_logits(f::Flux.Chain, x) = f[end] isa Function ? f[1:end-1](x) : f(x)

pat-alt's avatar
pat-alt committed
# Methods
pat-alt's avatar
pat-alt committed
@doc raw"""
    Models.logits(M::ConformalModel, X::AbstractArray)

To keep things consistent with the architecture of `CounterfactualExplanations.jl`, this method computes logits $\beta_i x_i$ (i.e. the linear predictions) for a Conformal Classifier.

The logits are read directly from the underlying chain(s): each chain returned by `_get_chains(M.fitresult)` is evaluated through `get_logits` (which skips a trailing activation function), and the results are averaged across chains.

In the binary case (`M.likelihood == :classification_binary`) logits are fed through the sigmoid function instead of softmax, so the two-row output is collapsed column by column as follows,

```math
\beta x = \beta_1 x_1 - \beta_0 x_0
```

which follows from the derivation here: https://stats.stackexchange.com/questions/233658/softmax-vs-sigmoid-function-in-logistic-classifier

The return value always has at least two dimensions.
"""
function Models.logits(M::ConformalModel, X::AbstractArray)
    # Evaluate every chain and combine the outputs into one array
    # (presumably along a new trailing dimension — TODO confirm for the MLUtils version in use).
    per_chain = map(chain -> get_logits(chain, X), _get_chains(M.fitresult))
    stacked = MLUtils.stack(per_chain)

    # Average over the last dimension and take the first slice along it.
    averaged = mean(stacked; dims=ndims(stacked))
    out = MLUtils.unstack(averaged; dims=ndims(averaged))[1]

    # `reduce(hcat, ...)` iterates its argument, so a matrix is wrapped to pass through as is.
    if ndims(out) == 2
        out = [out]
    end
    out = reduce(hcat, out)

    # Binary case: one logit per column, second row minus first row.
    if M.likelihood == :classification_binary
        out = reduce(hcat, map(col -> col[2] - col[1], eachcol(out)))
    end

    # Guarantee an array with more than one dimension.
    return ndims(out) > 1 ? out : permutedims([out])
end

pat-alt's avatar
uh  
pat-alt committed
"""
    Models.probs(M::ConformalModel, X::AbstractArray)

Returns the estimated probabilities for a Conformal Classifier.
"""
pat-alt's avatar
pat-alt committed
function Models.probs(M::ConformalModel, X::AbstractArray)
Pat Alt's avatar
Pat Alt committed

pat-alt's avatar
pat-alt committed
    if M.likelihood == :classification_binary
pat-alt's avatar
pat-alt committed
        output = σ.(Models.logits(M, X))
pat-alt's avatar
pat-alt committed
    elseif M.likelihood == :classification_multi
pat-alt's avatar
pat-alt committed
        output = softmax(Models.logits(M, X))
pat-alt's avatar
pat-alt committed
    end
    return output
pat-alt's avatar
uh  
pat-alt committed
end

"""
    train(M::ConformalModel, data::CounterfactualData; kwrgs...)

Trains a Conformal Classifier `M` on `data`. 
"""
function Models.train(M::ConformalModel, data::CounterfactualData; kwrgs...)
    X, y = data.X, data.output_encoder.labels
    X = table(permutedims(X))
    conf_model = M.model
    mach = machine(conf_model, X, y)
    fit!(mach; kwrgs...)
    likelihood, _ = CounterfactualExplanations.guess_likelihood(data.output_encoder.y)
    return ConformalModel(mach.model, mach.fitresult, likelihood)
pat-alt's avatar
pat-alt committed
end