diff --git a/src/penalties.jl b/src/penalties.jl
index 030a22e6b2bc1c05968f36fee1d81f5753e75a15..7c57a06768b6547c60d500bb81bec34b860975be 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -73,7 +73,13 @@ function energy_delta(
     # end
     conditional_samples = []
     ignore_derivatives() do 
-        xsampled = ECCCo.EnergySampler(ce; niter=niter, nsamples=ce.num_counterfactuals, kwargs...)
+        _dict = ce.params
+        if !(:energy_sampler ∈ collect(keys(_dict)))
+            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
+        end
+        eng_sampler = _dict[:energy_sampler]
+        generate_samples!(eng_sampler, ce.num_counterfactuals, get_target_index(ce.data.y_levels, ce.target); niter=niter)
+        xsampled = eng_sampler.buffer[:,(end-ce.num_counterfactuals+1):end]
         push!(conditional_samples, xsampled)
     end
 
diff --git a/src/sampling.jl b/src/sampling.jl
index ca51f37152e366a6f31e1f4ce9bd3293e8bd6cac..f62f26f8680e22200ed9b572db57fc795b1fef44 100644
--- a/src/sampling.jl
+++ b/src/sampling.jl
@@ -88,7 +88,6 @@ Generates `n` samples from `EnergySampler` for conditioning value `y`.
 function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
 
     # Generate samples:
-    # f = e.model.fitresult[1]
     f(x) = logits(e.model, x)
     rule = e.opt
     xsamples = e.sampler(f, rule; niter=niter, n_samples=n, y=y)