diff --git a/src/penalties.jl b/src/penalties.jl index 030a22e6b2bc1c05968f36fee1d81f5753e75a15..7c57a06768b6547c60d500bb81bec34b860975be 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -73,7 +73,13 @@ function energy_delta( # end conditional_samples = [] ignore_derivatives() do - xsampled = ECCCo.EnergySampler(ce; niter=niter, nsamples=ce.num_counterfactuals, kwargs...) + _dict = ce.params + if !(:energy_sampler ∈ collect(keys(_dict))) + _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) + end + eng_sampler = _dict[:energy_sampler] + generate_samples!(eng_sampler, ce.num_counterfactuals, get_target_index(ce.data.y_levels, ce.target); niter=niter) + xsampled = eng_sampler.buffer[:,(end-ce.num_counterfactuals+1):end] push!(conditional_samples, xsampled) end diff --git a/src/sampling.jl b/src/sampling.jl index ca51f37152e366a6f31e1f4ce9bd3293e8bd6cac..f62f26f8680e22200ed9b572db57fc795b1fef44 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -88,7 +88,6 @@ Generates `n` samples from `EnergySampler` for conditioning value `y`. function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100) # Generate samples: - # f = e.model.fitresult[1] f(x) = logits(e.model, x) rule = e.opt xsamples = e.sampler(f, rule; niter=niter, n_samples=n, y=y)