Skip to content
Snippets Groups Projects
Commit b607e478 authored by pat-alt's avatar pat-alt
Browse files

uh

parent 356d7aa7
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
julia_version = "1.8.5" julia_version = "1.8.5"
manifest_format = "2.0" manifest_format = "2.0"
project_hash = "7dee1a21a38b8ebe5b8194926ba2e4b4df21760e" project_hash = "a3f5e7cf561325560c244a222fe34c51dfd04226"
[[deps.AbstractFFTs]] [[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"] deps = ["ChainRulesCore", "LinearAlgebra"]
......
...@@ -8,6 +8,7 @@ ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" ...@@ -8,6 +8,7 @@ ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f"
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
......
```{julia}
using CCE
using ConformalPrediction
using CounterfactualExplanations
using CounterfactualExplanations.Data
using MLJFlux
```
```{julia}
counterfactual_data = load_linearly_separable()
```
```{julia}
clf = NeuralNetworkClassifier()
```
```{julia}
X = table(permutedims(counterfactual_data.X))
y = counterfactual_data.output_encoder.labels
conf_model = conformal_model(clf; method=:simple_inductive)
mach = machine(conf_model, X, y)
fit!(mach)
```
...@@ -6,14 +6,14 @@ using SliceMap ...@@ -6,14 +6,14 @@ using SliceMap
using Statistics using Statistics
""" """
Models.ConformalModel <: AbstractDifferentiableJuliaModel ConformalModel <: Models.AbstractDifferentiableJuliaModel
Constructor for models trained in `Flux.jl`. Constructor for models trained in `Flux.jl`.
""" """
struct Models.ConformalModel <: AbstractDifferentiableJuliaModel struct ConformalModel <: Models.AbstractDifferentiableJuliaModel
model::ConformalPrediction.ConformalProbabilisticSet model::ConformalPrediction.ConformalProbabilisticSet
likelihood::Symbol likelihood::Symbol
function FluxModel(model, likelihood) function ConformalModel(model, likelihood)
if likelihood [:classification_binary, :classification_multi] if likelihood [:classification_binary, :classification_multi]
new(model, likelihood) new(model, likelihood)
else else
...@@ -27,16 +27,16 @@ struct Models.ConformalModel <: AbstractDifferentiableJuliaModel ...@@ -27,16 +27,16 @@ struct Models.ConformalModel <: AbstractDifferentiableJuliaModel
end end
# Outer constructor method: # Outer constructor method:
function Models.ConformalModel(model; likelihood::Symbol=:classification_binary) function ConformalModel(model; likelihood::Symbol=:classification_binary)
Models.ConformalModel(model, likelihood) ConformalModel(model, likelihood)
end end
# Methods # Methods
function logits(M::Models.ConformalModel, X::AbstractArray) function logits(M::ConformalModel, X::AbstractArray)
return SliceMap.slicemap(x -> M.model(x), X, dims=(1, 2)) return SliceMap.slicemap(x -> M.model(x), X, dims=(1, 2))
end end
function probs(M::Models.ConformalModel, X::AbstractArray) function probs(M::ConformalModel, X::AbstractArray)
if M.likelihood == :classification_binary if M.likelihood == :classification_binary
output = σ.(logits(M, X)) output = σ.(logits(M, X))
elseif M.likelihood == :classification_multi elseif M.likelihood == :classification_multi
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment