# model.jl (1.27 KiB), committed by pat-alt
using ConformalPrediction
using CounterfactualExplanations.Models
using Flux
using MLUtils
using SliceMap
using Statistics

"""
    Models.ConformalModel <: AbstractDifferentiableJuliaModel

Wrapper that stores a conformal classifier together with the likelihood used to
turn its outputs into probabilities.

# Fields
- `model::ConformalPrediction.ConformalProbabilisticSet`: the wrapped conformal model.
- `likelihood::Symbol`: either `:classification_binary` or `:classification_multi`.

# Throws
- `ArgumentError` if `likelihood` is not one of the two supported symbols.
"""
struct Models.ConformalModel <: AbstractDifferentiableJuliaModel
    model::ConformalPrediction.ConformalProbabilisticSet
    likelihood::Symbol
    # NOTE(review): Julia struct definitions normally take an unqualified name; the
    # `Models.` prefix in the header above likely fails to parse. It is kept here
    # because the methods in this file dispatch on `Models.ConformalModel` — confirm
    # and rename consistently across the file.
    function ConformalModel(model, likelihood)
        # Validate up front so downstream methods can rely on a supported likelihood.
        if likelihood ∉ (:classification_binary, :classification_multi)
            throw(
                ArgumentError(
                    "`likelihood` should be in `[:classification_binary,:classification_multi]`",
                ),
            )
        end
        return new(model, likelihood)
    end
end

# Outer constructor method:
"""
    Models.ConformalModel(model; likelihood::Symbol=:classification_binary)

Keyword-argument convenience constructor: forwards `model` and `likelihood` to the
two-argument positional constructor. `likelihood` defaults to
`:classification_binary`.
"""
function Models.ConformalModel(model; likelihood::Symbol=:classification_binary)
    return Models.ConformalModel(model, likelihood)
end

# Methods
"""
    logits(M::Models.ConformalModel, X::AbstractArray)

Return the raw outputs of the wrapped model for `X`.

The wrapped `M.model` is called on slices of `X` through `SliceMap.slicemap` with
`dims=(1, 2)`.
"""
function logits(M::Models.ConformalModel, X::AbstractArray)
    # presumably each (1, 2) slice is a features × samples matrix — TODO confirm
    return SliceMap.slicemap(X; dims=(1, 2)) do slice
        M.model(slice)
    end
end

"""
    probs(M::Models.ConformalModel, X::AbstractArray)

Return predicted probabilities for `X`, computed from `logits(M, X)`.

- For `M.likelihood == :classification_binary`, `σ` is broadcast over the logits.
- For `M.likelihood == :classification_multi`, `softmax` is applied to the logits.

# Throws
- `ArgumentError` if `M.likelihood` is neither of the two supported symbols.
"""
function probs(M::Models.ConformalModel, X::AbstractArray)
    if M.likelihood == :classification_binary
        return σ.(logits(M, X))
    elseif M.likelihood == :classification_multi
        return softmax(logits(M, X))
    else
        # Fail with a descriptive error instead of falling through to an
        # undefined local variable.
        throw(
            ArgumentError(
                "`likelihood` should be in `[:classification_binary,:classification_multi]`, got `:$(M.likelihood)`",
            ),
        )
    end
end