From 9149cc7f4378861d2befe817a4149f6ad155f811 Mon Sep 17 00:00:00 2001
From: pat-alt <altmeyerpat@gmail.com>
Date: Sat, 18 Mar 2023 12:38:51 +0100
Subject: [PATCH] more work on notebook

---
 notebooks/fidelity.qmd | 6 +++++-
 src/sampling.jl        | 3 ++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/notebooks/fidelity.qmd b/notebooks/fidelity.qmd
index e121e4df..e3453d57 100644
--- a/notebooks/fidelity.qmd
+++ b/notebooks/fidelity.qmd
@@ -212,7 +212,11 @@ counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactua
 plts = []
 for name ∈ ordered_names
     ce = counterfactuals[name]
-    plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=0)
+    plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=-0.5)
+    if :energy_sampler ∈ collect(keys(ce.params))
+        Xgen = ce.params[:energy_sampler].buffer
+        scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
+    end
     plts = vcat(plts..., plt)
 end
 _n = length(generators)
diff --git a/src/sampling.jl b/src/sampling.jl
index 229e4b69..4f57fc40 100644
--- a/src/sampling.jl
+++ b/src/sampling.jl
@@ -38,7 +38,7 @@ function EnergySampler(
     y::Any;
     opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(),
     niter::Int=100,
-    nsamples::Int=1000
+    nsamples::Int=100
 )
 
     @assert y ∈ data.y_levels || y ∈ 1:length(data.y_levels)
@@ -64,6 +64,7 @@ Generates `n` samples from `EnergySampler` for conditioning value `y`.
 """
 function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
     X = e.sampler(e.model, e.opt, (size(e.data.X, 1), n); niter=niter, y=y)
+    X = X[:,map(x -> !any(isnan.(x)), eachcol(X))]
     return X
 end
 
-- 
GitLab