From 9149cc7f4378861d2befe817a4149f6ad155f811 Mon Sep 17 00:00:00 2001 From: pat-alt <altmeyerpat@gmail.com> Date: Sat, 18 Mar 2023 12:38:51 +0100 Subject: [PATCH] more work on notebook --- notebooks/fidelity.qmd | 6 +++++- src/sampling.jl | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/notebooks/fidelity.qmd b/notebooks/fidelity.qmd index e121e4df..e3453d57 100644 --- a/notebooks/fidelity.qmd +++ b/notebooks/fidelity.qmd @@ -212,7 +212,11 @@ counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactua plts = [] for name ∈ ordered_names ce = counterfactuals[name] - plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=0) + plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=-0.5) + if :energy_sampler ∈ collect(keys(ce.params)) + Xgen = ce.params[:energy_sampler].buffer + scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target") + end plts = vcat(plts..., plt) end _n = length(generators) diff --git a/src/sampling.jl b/src/sampling.jl index 229e4b69..4f57fc40 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -38,7 +38,7 @@ function EnergySampler( y::Any; opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(), niter::Int=100, - nsamples::Int=1000 + nsamples::Int=100 ) @assert y ∈ data.y_levels || y ∈ 1:length(data.y_levels) @@ -64,6 +64,7 @@ Generates `n` samples from `EnergySampler` for conditioning value `y`. """ function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100) X = e.sampler(e.model, e.opt, (size(e.data.X, 1), n); niter=niter, y=y) + X = X[:,map(x -> !any(isnan.(x)), eachcol(X))] return X end -- GitLab