diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd
index 57b3f6d79a9c51f6134e9600f28c089d654be724..9c45239d3b594907d78abc306dbbcf901e96ba5b 100644
--- a/notebooks/conformal.qmd
+++ b/notebooks/conformal.qmd
@@ -31,7 +31,20 @@ mach = machine(conf_model, X, y)
 fit!(mach)
 ```
 
+```{julia}
+contourf(mach.model, mach.fitresult, X, y; plot_set_size=true)
+```
+
+## Counterfactual Explanation
+
+We now wrap the fitted conformal model for use with `CounterfactualExplanations.jl` and instantiate our `ConformalGenerator`:
 ```{julia}
 M = CCE.ConformalModel(conf_model, mach.fitresult)
+generator = CCE.ConformalGenerator()
 ```
 
+```{julia}
+x = select_factual(counterfactual_data, rand(1:size(counterfactual_data.X, 2)))
+y = predict_label(M, counterfactual_data, x)[1]
+target = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y][1]
+```
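+
+Finally, a counterfactual for the chosen factual and target can be generated. This is a minimal sketch that relies on the generic `generate_counterfactual` interface of `CounterfactualExplanations.jl`, leaving all search parameters at their defaults:
+
+```{julia}
+ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
+```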
\ No newline at end of file
diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl
index bde032bd2aea466c0b1729d09823b37755b60ab1..b86bccc6a85f4f6e217e8b9c5f25c89ca90d93f1 100644
--- a/src/ConformalGenerator.jl
+++ b/src/ConformalGenerator.jl
@@ -3,6 +3,7 @@ using CounterfactualExplanations.Generators
+using ConformalPrediction
 using Flux
 using LinearAlgebra
 using Parameters
+using SliceMap
 using Statistics
 
 mutable struct ConformalGenerator <: AbstractGradientBasedGenerator
@@ -19,7 +20,7 @@ end
     opt::Flux.Optimise.AbstractOptimiser = Descent()
     τ::AbstractFloat = 1e-3
     κ::Real = 1.0
-    Τ::Real = 0.5
+    temp::Real = 0.5
 end
 
 """
@@ -50,31 +51,6 @@ function ConformalGenerator(;
     ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ)
 end
 
-# Loss:
-# """
-#     ℓ(generator::ConformalGenerator, counterfactual_explanation::AbstractCounterfactualExplanation)
-
-# The default method to apply the generator loss function to the current counterfactual state for any generator.
-# """
-# function ℓ(
-#     generator::ConformalGenerator,
-#     counterfactual_explanation::AbstractCounterfactualExplanation,
-# )
-
-#     loss_fun =
-#         !isnothing(generator.loss) ? getfield(Losses, generator.loss) :
-#         CounterfactualExplanations.guess_loss(counterfactual_explanation)
-#     @assert !isnothing(loss_fun) "No loss function provided and loss function could not be guessed based on model."
-#     loss = loss_fun(
-#         getfield(Models, :logits)(
-#             counterfactual_explanation.M,
-#             CounterfactualExplanations.decode_state(counterfactual_explanation),
-#         ),
-#         counterfactual_explanation.target_encoded,
-#     )
-#     return loss
-# end
-
 """
     set_size_penalty(
         generator::ConformalGenerator,
@@ -88,6 +64,19 @@ function set_size_penalty(
     counterfactual_explanation::AbstractCounterfactualExplanation,
 )
 
+    conf_model = counterfactual_explanation.M.model
+    fitresult = counterfactual_explanation.M.fitresult
+    X = CounterfactualExplanations.decode_state(counterfactual_explanation)
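+    # Average the smooth set-size penalty over slices of the (possibly batched) counterfactual state: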
+    loss = SliceMap.slicemap(X, dims=(1,2)) do x
+        ConformalPrediction.smooth_size_loss(
+            conf_model, fitresult, x;
+            κ = generator.κ,
+            temp = generator.temp
+        )
+    end
+    loss = mean(loss)
+
+    return loss
 
 end
 
@@ -109,17 +98,12 @@ function Generators.h(
     )
 
-    # Euclidean norm of gradient:
+    # Smooth conformal set-size penalty:
-    in_target_domain = all(target_probs(counterfactual_explanation) .>= 0.5)
-    if in_target_domain
-        grad_norm = gradient_penalty(generator, counterfactual_explanation)
-    else
-        grad_norm = 0
-    end
+    Ω = set_size_penalty(generator, counterfactual_explanation)
 
     if length(generator.λ) == 1
-        penalty = generator.λ * (dist_ .+ grad_norm)
+        penalty = generator.λ * (dist_ .+ Ω)
     else
-        penalty = generator.λ[1] * dist_ .+ generator.λ[2] * grad_norm
+        penalty = generator.λ[1] * dist_ .+ generator.λ[2] * Ω
     end
     return penalty
 end
diff --git a/src/model.jl b/src/model.jl
index 2abc8fa4404c537c32c1ff496d8f7acae4397549..1db725d1d3ea680cc31f098539b71ba351127dd5 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -59,8 +59,8 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
     yhat = SliceMap.slicemap(X, dims=(1, 2)) do x
         conf_model = M.model
         fitresult = M.fitresult
-        X = MLJBase.table(permutedims(X))
-        p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, X)...)
+        x = MLJBase.table(permutedims(x))
+        p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, x)...)
         p̂ = map(p̂) do pp
             L = p̂.decoder.classes
             probas = pdf.(pp, L)
@@ -69,8 +69,9 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
         p̂ = reduce(hcat, p̂)
         ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂))))
         if M.likelihood == :classification_binary
-            p̂ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ))))
+            ŷ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ))))
         end
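+        # Return the logits explicitly so that `slicemap` collects ŷ for every slice: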
+        return ŷ
     end
     return yhat
 end