diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd
index 9c45239d3b594907d78abc306dbbcf901e96ba5b..8a84a6cd9dd4867228672e5351b8743152adc3c4 100644
--- a/notebooks/conformal.qmd
+++ b/notebooks/conformal.qmd
@@ -47,4 +47,8 @@ generator = CCE.ConformalGenerator()
 x = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))
 y = predict_label(M, counterfactual_data, x)[1]
 target = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y][1]
-```
\ No newline at end of file
+```
+
+```{julia}
+ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
+```
diff --git a/src/model.jl b/src/model.jl
index 1db725d1d3ea680cc31f098539b71ba351127dd5..654af645614f6fd74519dd63812648f7855ba684 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -59,18 +59,22 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
     yhat = SliceMap.slicemap(X, dims=(1, 2)) do x
         conf_model = M.model
         fitresult = M.fitresult
-        x = MLJBase.table(permutedims(x))
-        p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, x)...)
-        p̂ = map(p̂) do pp
-            L = p̂.decoder.classes
-            probas = pdf.(pp, L)
-            return probas
+        # x = MLJBase.table(permutedims(x))
+        # p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, x)...)
+        # p̂ = map(p̂) do pp
+        #     L = p̂.decoder.classes
+        #     probas = pdf.(pp, L)
+        #     return probas
+        # end
+        p̂ = fitresult[1](x)
+        if size(p̂, 2) > 1
+            p̂ = reduce(hcat, p̂)
         end
-        p̂ = reduce(hcat, p̂)
         ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂))))
         if M.likelihood == :classification_binary
             ŷ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ))))
         end
+        ŷ = ndims(ŷ) > 1 ? ŷ : permutedims([ŷ])
         return ŷ
     end
     return yhat