diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd
index 015fbe0d03bf0a73a962ea731533f0efc4d19990..1dfdff9849fe6132a4c23876bee4bd30000d5c52 100644
--- a/notebooks/conformal.qmd
+++ b/notebooks/conformal.qmd
@@ -5,6 +5,7 @@ using CounterfactualExplanations
 using CounterfactualExplanations.Data
 using CounterfactualExplanations.Objectives
 using Flux
+using LinearAlgebra
 using MLJBase
 using MLJFlux
 using Plots
@@ -102,9 +103,9 @@ is a smooth size penalty for conformal classifiers introduced by @stutz2022learn
 #| echo: false
 #| label: fig-losses
 #| fig-cap: "Illustration of the smooth size loss and the configurable classification loss."
-
-p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true, zoom=0)
-p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=target, zoom=0)
+temp = 0.5
+p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true, zoom=0, temp=temp)
+p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=target, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2,2))
 plot(p1, p2, size=(800,320))
 ```
 
@@ -157,6 +158,7 @@ plot(plts..., size=(_n * img_size,1.05*img_size), layout=(1,_n))
 #| fig-cap: "Comparison of counterfactuals produced using different generators."
 
 opt = Descent(0.01)
+temp = 0.75
 ordered_names = [
     "Generic (γ=0.5)",
     "Generic (γ=0.9)",
@@ -169,8 +171,8 @@ loss_fun = Objectives.logitbinarycrossentropy
 generators = Dict(
     ordered_names[1] => GenericGenerator(opt = opt, decision_threshold=0.5),
     ordered_names[2] => GenericGenerator(opt = opt, decision_threshold=0.9),
-    ordered_names[3] => CCE.ConformalGenerator(opt=opt, λ=[0.1,1]),
-    ordered_names[4] => CCE.ConformalGenerator(opt=opt, λ=[0.1,10]),
+    ordered_names[3] => CCE.ConformalGenerator(opt=opt, λ=[0.1,1], temp=temp),
+    ordered_names[4] => CCE.ConformalGenerator(opt=opt, λ=[0.1,10], temp=temp),
 )
 
 counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity) for (name, gen) in generators])
diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl
index 280e483c84c5025af68d0656ca524a7ee1c12585..b169399b47a65d4fe0fff8c35087bf68c990017a 100644
--- a/src/ConformalGenerator.jl
+++ b/src/ConformalGenerator.jl
@@ -63,18 +63,21 @@ A configurable classification loss function for Conformal Predictors.
 function conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...)
     conf_model = counterfactual_explanation.M.model
     fitresult = counterfactual_explanation.M.fitresult
+    generator = counterfactual_explanation.generator
+    temp = hasfield(typeof(generator), :temp) ? generator.temp : nothing
+    K = length(counterfactual_explanation.data.y_levels)
     X = CounterfactualExplanations.decode_state(counterfactual_explanation)
     y = counterfactual_explanation.target_encoded[:,:,1]
     if counterfactual_explanation.M.likelihood == :classification_binary
         y = binary_to_onehot(y)
     end
     y = permutedims(y)
-    generator = counterfactual_explanation.generator
     loss = SliceMap.slicemap(X, dims=(1, 2)) do x
-        x = Matrix(x)
+        x = Matrix(x)
         ConformalPrediction.classification_loss(
             conf_model, fitresult, x, y;
-            temp=generator.temp
+            temp=temp,
+            loss_matrix=Float32.(ones(K,K))
         )
     end
     loss = mean(loss)