From 4dfb424e268f2b5dc1851b524d33cc21307141d8 Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Wed, 13 Sep 2023 10:11:33 +0200 Subject: [PATCH] reconsidering sampling strat --- src/penalties.jl | 8 +++++++- src/sampling.jl | 1 - 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/penalties.jl b/src/penalties.jl index 030a22e6..7c57a067 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -73,7 +73,13 @@ function energy_delta( # end conditional_samples = [] ignore_derivatives() do - xsampled = ECCCo.EnergySampler(ce; niter=niter, nsamples=ce.num_counterfactuals, kwargs...) + _dict = ce.params + if !(:energy_sampler ∈ collect(keys(_dict))) + _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) + end + eng_sampler = _dict[:energy_sampler] + generate_samples!(eng_sampler, ce.num_counterfactuals, get_target_index(ce.data.y_levels, ce.target); niter=niter) + xsampled = eng_sampler.buffer[:,(end-ce.num_counterfactuals+1):end] push!(conditional_samples, xsampled) end diff --git a/src/sampling.jl b/src/sampling.jl index ca51f371..f62f26f8 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -88,7 +88,6 @@ Generates `n` samples from `EnergySampler` for conditioning value `y`. function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100) # Generate samples: - # f = e.model.fitresult[1] f(x) = logits(e.model, x) rule = e.opt xsamples = e.sampler(f, rule; niter=niter, n_samples=n, y=y) -- GitLab