Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
using ConformalPrediction
using CounterfactualExplanations.Models
using Flux
using MLUtils
using SliceMap
using Statistics
"""
Models.ConformalModel <: AbstractDifferentiableJuliaModel
Constructor for models trained in `Flux.jl`.
"""
struct Models.ConformalModel <: AbstractDifferentiableJuliaModel
model::ConformalPrediction.ConformalProbabilisticSet
likelihood::Symbol
function FluxModel(model, likelihood)
if likelihood ∈ [:classification_binary, :classification_multi]
new(model, likelihood)
else
throw(
ArgumentError(
"`type` should be in `[:classification_binary,:classification_multi]`",
),
)
end
end
end
# Outer constructor method:
function Models.ConformalModel(model; likelihood::Symbol=:classification_binary)
Models.ConformalModel(model, likelihood)
end
# Methods
function logits(M::Models.ConformalModel, X::AbstractArray)
return SliceMap.slicemap(x -> M.model(x), X, dims=(1, 2))
end
function probs(M::Models.ConformalModel, X::AbstractArray)
if M.likelihood == :classification_binary
output = σ.(logits(M, X))
elseif M.likelihood == :classification_multi
output = softmax(logits(M, X))
end
return output
end