From b607e47850dd946c8c13a720e431285183576b44 Mon Sep 17 00:00:00 2001 From: pat-alt <altmeyerpat@gmail.com> Date: Tue, 14 Feb 2023 08:17:19 +0100 Subject: [PATCH] uh --- Manifest.toml | 2 +- Project.toml | 1 + notebooks/conformal.qmd | 25 +++++++++++++++++++++++++ src/model.jl | 14 +++++++------- 4 files changed, 34 insertions(+), 8 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index f445e33d..895027ab 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "7dee1a21a38b8ebe5b8194926ba2e4b4df21760e" +project_hash = "a3f5e7cf561325560c244a222fe34c51dfd04226" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] diff --git a/Project.toml b/Project.toml index d6df8473..5a1cf95c 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd index e69de29b..015e7e52 100644 --- a/notebooks/conformal.qmd +++ b/notebooks/conformal.qmd @@ -0,0 +1,25 @@ + +```{julia} +using CCE +using ConformalPrediction +using CounterfactualExplanations +using CounterfactualExplanations.Data +using MLJFlux +``` + +```{julia} +counterfactual_data = load_linearly_separable() +``` + +```{julia} +clf = NeuralNetworkClassifier() +``` + + +```{julia} +X = table(permutedims(counterfactual_data.X)) +y = counterfactual_data.output_encoder.labels +conf_model = conformal_model(clf; method=:simple_inductive) +mach = machine(conf_model, X, y) +fit!(mach) +``` diff --git a/src/model.jl b/src/model.jl index 1a1d3d81..dc7e8a4a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -6,14 +6,14 @@ using SliceMap using Statistics """ - Models.ConformalModel <: AbstractDifferentiableJuliaModel + ConformalModel <: Models.AbstractDifferentiableJuliaModel Constructor for models trained in `Flux.jl`. """ -struct Models.ConformalModel <: AbstractDifferentiableJuliaModel +struct ConformalModel <: Models.AbstractDifferentiableJuliaModel model::ConformalPrediction.ConformalProbabilisticSet likelihood::Symbol - function FluxModel(model, likelihood) + function ConformalModel(model, likelihood) if likelihood ∈ [:classification_binary, :classification_multi] new(model, likelihood) else @@ -27,16 +27,16 @@ struct Models.ConformalModel <: AbstractDifferentiableJuliaModel end # Outer constructor method: -function Models.ConformalModel(model; likelihood::Symbol=:classification_binary) - Models.ConformalModel(model, likelihood) +function ConformalModel(model; likelihood::Symbol=:classification_binary) + ConformalModel(model, likelihood) end # Methods -function logits(M::Models.ConformalModel, X::AbstractArray) +function logits(M::ConformalModel, X::AbstractArray) return SliceMap.slicemap(x -> M.model(x), X, dims=(1, 2)) end -function probs(M::Models.ConformalModel, X::AbstractArray) +function probs(M::ConformalModel, X::AbstractArray) if M.likelihood == :classification_binary output = σ.(logits(M, X)) elseif M.likelihood == :classification_multi -- GitLab