From 4dfb424e268f2b5dc1851b524d33cc21307141d8 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Wed, 13 Sep 2023 10:11:33 +0200
Subject: [PATCH] reconsidering sampling strat

---
 src/penalties.jl | 8 +++++++-
 src/sampling.jl  | 1 -
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/penalties.jl b/src/penalties.jl
index 030a22e6..7c57a067 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -73,7 +73,13 @@ function energy_delta(
     # end
     conditional_samples = []
     ignore_derivatives() do 
-        xsampled = ECCCo.EnergySampler(ce; niter=niter, nsamples=ce.num_counterfactuals, kwargs...)
+        _dict = ce.params
+        if !(:energy_sampler ∈ collect(keys(_dict)))
+            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
+        end
+        eng_sampler = _dict[:energy_sampler]
+        generate_samples!(eng_sampler, ce.num_counterfactuals, get_target_index(ce.data.y_levels, ce.target); niter=niter)
+        xsampled = eng_sampler.buffer[:,(end-ce.num_counterfactuals+1):end]
         push!(conditional_samples, xsampled)
     end
 
diff --git a/src/sampling.jl b/src/sampling.jl
index ca51f371..f62f26f8 100644
--- a/src/sampling.jl
+++ b/src/sampling.jl
@@ -88,7 +88,6 @@ Generates `n` samples from `EnergySampler` for conditioning value `y`.
 function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
 
     # Generate samples:
-    # f = e.model.fitresult[1]
     f(x) = logits(e.model, x)
     rule = e.opt
     xsamples = e.sampler(f, rule; niter=niter, n_samples=n, y=y)
-- 
GitLab