diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd
index 015e7e52c7431dbe684b21618f48a441025285c5..57b3f6d79a9c51f6134e9600f28c089d654be724 100644
--- a/notebooks/conformal.qmd
+++ b/notebooks/conformal.qmd
@@ -4,7 +4,10 @@ using CCE
 using ConformalPrediction
 using CounterfactualExplanations
 using CounterfactualExplanations.Data
+using Flux
+using MLJBase
 using MLJFlux
+using Plots
 ```
 
 ```{julia}
@@ -12,7 +15,11 @@ counterfactual_data = load_linearly_separable()
 ```
 
 ```{julia}
-clf = NeuralNetworkClassifier()
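+# Use a small MLP (one hidden layer) as the network architecture: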
+builder = MLJFlux.@builder Chain(
+    Dense(n_in, 32, relu),
+    Dense(32, n_out)
+)
+clf = NeuralNetworkClassifier(builder=builder, epochs=100)
 ```
 
 
@@ -23,3 +30,8 @@ conf_model = conformal_model(clf; method=:simple_inductive)
 mach = machine(conf_model, X, y)
 fit!(mach)
 ```
+
+```{julia}
+M = CCE.ConformalModel(conf_model, mach.fitresult)
+```
+
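+As a quick sanity check (illustrative; this assumes `X` is the feature table used to fit the machine above), the conformal model's probabilistic predictions should be recoverable from the inverted logits, since the inversion in `CCE` is exact up to the softmax/sigmoid:
+
+```{julia}
+# Convert the table to a matrix with features as rows, as expected by `probs`:
+Xmat = permutedims(MLJBase.matrix(X))
+CounterfactualExplanations.Models.probs(M, Xmat)
+```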
diff --git a/src/CCE.jl b/src/CCE.jl
index 58f70b9cdc1699e50aa93f9a2a413a221ed405e5..e0fe38fa371838c1d18b50807ab0dd3004894e45 100644
--- a/src/CCE.jl
+++ b/src/CCE.jl
@@ -1,5 +1,7 @@
 module CCE
 
+import MLJModelInterface as MMI
+
 include("model.jl")
 include("ConformalGenerator.jl")
 
diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl
index c772b69a64248905d8531fabf67383f26af41ed9..540d71a6b64d446bcbdbb219cdca1aaeaee89a80 100644
--- a/src/ConformalGenerator.jl
+++ b/src/ConformalGenerator.jl
@@ -48,6 +48,31 @@ function ConformalGenerator(;
     ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ)
 end
 
+# Loss:
+# """
+#     ℓ(generator::ConformalGenerator, counterfactual_explanation::AbstractCounterfactualExplanation)
+
+# The default method to apply the generator loss function to the current counterfactual state for any generator.
+# """
+# function ℓ(
+#     generator::ConformalGenerator,
+#     counterfactual_explanation::AbstractCounterfactualExplanation,
+# )
+
+#     loss_fun =
+#         !isnothing(generator.loss) ? getfield(Losses, generator.loss) :
+#         CounterfactualExplanations.guess_loss(counterfactual_explanation)
+#     @assert !isnothing(loss_fun) "No loss function provided and loss function could not be guessed based on model."
+#     loss = loss_fun(
+#         getfield(Models, :logits)(
+#             counterfactual_explanation.M,
+#             CounterfactualExplanations.decode_state(counterfactual_explanation),
+#         ),
+#         counterfactual_explanation.target_encoded,
+#     )
+#     return loss
+# end
+
 """
     set_size_penalty(
         generator::ConformalGenerator,
@@ -61,21 +86,7 @@ function set_size_penalty(
     counterfactual_explanation::AbstractCounterfactualExplanation,
 )
 
-    x_ = CounterfactualExplanations.decode_state(counterfactual_explanation)
-    M = counterfactual_explanation.M
-    model = isa(M.model, Vector) ? M.model : [M.model]
-    y_ = counterfactual_explanation.target_encoded
-
-    if M.likelihood == :classification_binary
-        loss_type = :logitbinarycrossentropy
-    else
-        loss_type = :logitcrossentropy
-    end
-
-    loss(x, y) =
-        sum([getfield(Flux.Losses, loss_type)(nn(x), y) for nn in model]) / length(model)
 
-    return loss(x_, y_)
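+    # TODO: re-implement the set size penalty for conformal predictors (the
+    # previous Flux-ensemble implementation does not apply to `ConformalModel`).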
 end
 
 # Complexity:
diff --git a/src/model.jl b/src/model.jl
index dc7e8a4ac612433cba9946acafad36602234d3bd..2abc8fa4404c537c32c1ff496d8f7acae4397549 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,6 +1,7 @@
 using ConformalPrediction
 using CounterfactualExplanations.Models
 using Flux
+using MLJBase
 using MLUtils
 using SliceMap
 using Statistics
@@ -12,10 +13,11 @@ Constructor for models trained in `Flux.jl`.
 """
 struct ConformalModel <: Models.AbstractDifferentiableJuliaModel
     model::ConformalPrediction.ConformalProbabilisticSet
+    fitresult::Any
     likelihood::Symbol
-    function ConformalModel(model, likelihood)
+    function ConformalModel(model, fitresult, likelihood)
         if likelihood ∈ [:classification_binary, :classification_multi]
-            new(model, likelihood)
+            new(model, fitresult, likelihood)
         else
             throw(
                 ArgumentError(
@@ -27,20 +29,57 @@ struct ConformalModel <: Models.AbstractDifferentiableJuliaModel
 end
 
 # Outer constructor method:
-function ConformalModel(model; likelihood::Symbol=:classification_binary)
-    ConformalModel(model, likelihood)
+function ConformalModel(model, fitresult; likelihood::Symbol=:classification_binary)
+    ConformalModel(model, fitresult, likelihood)
 end
 
 # Methods
-function logits(M::ConformalModel, X::AbstractArray)
-    return SliceMap.slicemap(x -> M.model(x), X, dims=(1, 2))
+@doc raw"""
+    Models.logits(M::ConformalModel, X::AbstractArray)
+
+To keep things consistent with the architecture of `CounterfactualExplanations.jl`, this method computes logits $\beta_i x_i$ (i.e. the linear predictions) for a Conformal Classifier. By default, `MLJ.jl` and `ConformalPrediction.jl` return probabilistic predictions. To get the underlying logits, we invert the softmax function. 
+
+Let $\hat{p}_i$ denote the estimated softmax output for class $i$. Since the softmax is invariant under adding a constant to every logit, the logits can only be recovered up to an additive constant. Fixing that constant at $\log \sum_j \exp(\hat{p}_j)$, the following formula can be applied in the multi-class case:
+
+```math
+\beta_i x_i = \log (\hat{p}_i) + \log \left( \sum_j \exp(\hat{p}_j) \right)
+```
+
+For a short derivation, see here: https://math.stackexchange.com/questions/2786600/invert-the-softmax-function. 
+
+In the binary case, the underlying model feeds a single logit through the sigmoid function rather than the softmax, so we additionally collapse the two class logits into one,
+
+```math
+\beta x = \beta_1 x_1 - \beta_0 x_0
+```    
+
+which follows from the derivation here: https://stats.stackexchange.com/questions/233658/softmax-vs-sigmoid-function-in-logistic-classifier
+"""
+function Models.logits(M::ConformalModel, X::AbstractArray)
+    yhat = SliceMap.slicemap(X, dims=(1, 2)) do x
+        conf_model = M.model
+        fitresult = M.fitresult
+        # Tabulate the slice (columns are samples) for MLJ:
+        Xtab = MLJBase.table(permutedims(x))
+        p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xtab)...)
+        # Collect the predicted probability of every class for each sample:
+        p̂ = map(p̂) do pp
+            L = MLJBase.classes(pp)
+            probas = pdf.(pp, L)
+            return probas
+        end
+        p̂ = reduce(hcat, p̂)
+        # Invert the softmax (up to an additive constant, see docstring):
+        ŷ = reduce(hcat, map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂)))
+        if M.likelihood == :classification_binary
+            # Collapse the two class logits into a single binary logit:
+            ŷ = reduce(hcat, map(y -> y[2] - y[1], eachcol(ŷ)))
+        end
+        return ŷ
+    end
+    return yhat
+end
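+
+# Note that in the binary case the additive constant cancels:
+# ŷ = log(p̂₂) - log(p̂₁) = log(p̂₂ / p̂₁), and σ(log(p̂₂ / p̂₁)) = p̂₂, so
+# `Models.probs` below recovers the predicted probability of the positive class.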
 
-function probs(M::ConformalModel, X::AbstractArray)
+function Models.probs(M::ConformalModel, X::AbstractArray)
     if M.likelihood == :classification_binary
-        output = σ.(logits(M, X))
+        output = σ.(Models.logits(M, X))
     elseif M.likelihood == :classification_multi
-        output = softmax(logits(M, X))
+        output = softmax(Models.logits(M, X))
     end
     return output
 end
\ No newline at end of file