Skip to content
Snippets Groups Projects
Commit 4dfb424e authored by Pat Alt's avatar Pat Alt
Browse files

reconsidering sampling strat

parent 162f2715
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
...@@ -73,7 +73,13 @@ function energy_delta( ...@@ -73,7 +73,13 @@ function energy_delta(
# end # end
conditional_samples = [] conditional_samples = []
ignore_derivatives() do ignore_derivatives() do
xsampled = ECCCo.EnergySampler(ce; niter=niter, nsamples=ce.num_counterfactuals, kwargs...) _dict = ce.params
if !(:energy_sampler collect(keys(_dict)))
_dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
end
eng_sampler = _dict[:energy_sampler]
generate_samples!(eng_sampler, ce.num_counterfactuals, get_target_index(ce.data.y_levels, ce.target); niter=niter)
xsampled = eng_sampler.buffer[:,(end-ce.num_counterfactuals+1):end]
push!(conditional_samples, xsampled) push!(conditional_samples, xsampled)
end end
......
...@@ -88,7 +88,6 @@ Generates `n` samples from `EnergySampler` for conditioning value `y`. ...@@ -88,7 +88,6 @@ Generates `n` samples from `EnergySampler` for conditioning value `y`.
function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100) function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
# Generate samples: # Generate samples:
# f = e.model.fitresult[1]
f(x) = logits(e.model, x) f(x) = logits(e.model, x)
rule = e.opt rule = e.opt
xsamples = e.sampler(f, rule; niter=niter, n_samples=n, y=y) xsamples = e.sampler(f, rule; niter=niter, n_samples=n, y=y)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment