diff --git a/Manifest.toml b/Manifest.toml index f445e33da81b4a0f39f5402beb88bcb6fec7654d..895027abfcdbe69cc3ec768bac44e3f7cce525a8 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "7dee1a21a38b8ebe5b8194926ba2e4b4df21760e" +project_hash = "a3f5e7cf561325560c244a222fe34c51dfd04226" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] diff --git a/Project.toml b/Project.toml index d6df84732c414063cc6b317e040cc4929c292518..5a1cf95cc1483a07831447524b54b82249b33491 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..015e7e52c7431dbe684b21618f48a441025285c5 100644 --- a/notebooks/conformal.qmd +++ b/notebooks/conformal.qmd @@ -0,0 +1,25 @@ + +```{julia} +using CCE +using ConformalPrediction +using CounterfactualExplanations +using CounterfactualExplanations.Data +using MLJFlux +``` + +```{julia} +counterfactual_data = load_linearly_separable() +``` + +```{julia} +clf = NeuralNetworkClassifier() +``` + + +```{julia} +X = table(permutedims(counterfactual_data.X)) +y = counterfactual_data.output_encoder.labels +conf_model = conformal_model(clf; method=:simple_inductive) +mach = machine(conf_model, X, y) +fit!(mach) +``` diff --git a/src/model.jl b/src/model.jl index 1a1d3d815de55ef4bb6f14d8ff6d8c4f328623cf..dc7e8a4ac612433cba9946acafad36602234d3bd 100644 --- a/src/model.jl +++ b/src/model.jl @@ -6,14 +6,14 @@ using SliceMap using Statistics """ - Models.ConformalModel <: AbstractDifferentiableJuliaModel + ConformalModel <: Models.AbstractDifferentiableJuliaModel Constructor for models trained in `Flux.jl`. """ -struct Models.ConformalModel <: AbstractDifferentiableJuliaModel +struct ConformalModel <: Models.AbstractDifferentiableJuliaModel model::ConformalPrediction.ConformalProbabilisticSet likelihood::Symbol - function FluxModel(model, likelihood) + function ConformalModel(model, likelihood) if likelihood ∈ [:classification_binary, :classification_multi] new(model, likelihood) else @@ -27,16 +27,16 @@ struct Models.ConformalModel <: AbstractDifferentiableJuliaModel end # Outer constructor method: -function Models.ConformalModel(model; likelihood::Symbol=:classification_binary) - Models.ConformalModel(model, likelihood) +function ConformalModel(model; likelihood::Symbol=:classification_binary) + ConformalModel(model, likelihood) end # Methods -function logits(M::Models.ConformalModel, X::AbstractArray) +function logits(M::ConformalModel, X::AbstractArray) return SliceMap.slicemap(x -> M.model(x), X, dims=(1, 2)) end -function probs(M::Models.ConformalModel, X::AbstractArray) +function probs(M::ConformalModel, X::AbstractArray) if M.likelihood == :classification_binary output = σ.(logits(M, X)) elseif M.likelihood == :classification_multi