Skip to content
Snippets Groups Projects
Commit b607e478 authored by pat-alt's avatar pat-alt
Browse files

uh

parent 356d7aa7
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@
julia_version = "1.8.5"
manifest_format = "2.0"
project_hash = "7dee1a21a38b8ebe5b8194926ba2e4b4df21760e"
project_hash = "a3f5e7cf561325560c244a222fe34c51dfd04226"
[[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"]
......
......@@ -8,6 +8,7 @@ ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f"
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
......
```{julia}
using CCE
using ConformalPrediction
using CounterfactualExplanations
using CounterfactualExplanations.Data
using MLJFlux
```
```{julia}
counterfactual_data = load_linearly_separable()
```
```{julia}
clf = NeuralNetworkClassifier()
```
```{julia}
X = table(permutedims(counterfactual_data.X))
y = counterfactual_data.output_encoder.labels
conf_model = conformal_model(clf; method=:simple_inductive)
mach = machine(conf_model, X, y)
fit!(mach)
```
......@@ -6,14 +6,14 @@ using SliceMap
using Statistics
"""
Models.ConformalModel <: AbstractDifferentiableJuliaModel
ConformalModel <: Models.AbstractDifferentiableJuliaModel
Constructor for models trained in `Flux.jl`.
"""
struct Models.ConformalModel <: AbstractDifferentiableJuliaModel
struct ConformalModel <: Models.AbstractDifferentiableJuliaModel
model::ConformalPrediction.ConformalProbabilisticSet
likelihood::Symbol
function FluxModel(model, likelihood)
function ConformalModel(model, likelihood)
if likelihood [:classification_binary, :classification_multi]
new(model, likelihood)
else
......@@ -27,16 +27,16 @@ struct Models.ConformalModel <: AbstractDifferentiableJuliaModel
end
# Outer constructor method:
function Models.ConformalModel(model; likelihood::Symbol=:classification_binary)
Models.ConformalModel(model, likelihood)
function ConformalModel(model; likelihood::Symbol=:classification_binary)
ConformalModel(model, likelihood)
end
# Methods
function logits(M::Models.ConformalModel, X::AbstractArray)
function logits(M::ConformalModel, X::AbstractArray)
return SliceMap.slicemap(x -> M.model(x), X, dims=(1, 2))
end
function probs(M::Models.ConformalModel, X::AbstractArray)
function probs(M::ConformalModel, X::AbstractArray)
if M.likelihood == :classification_binary
output = σ.(logits(M, X))
elseif M.likelihood == :classification_multi
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment