Skip to content
Snippets Groups Projects
Commit 9149cc7f authored by pat-alt's avatar pat-alt
Browse files

more work on notebook

parent 2a482afc
No related branches found
No related tags found
No related merge requests found
......@@ -212,7 +212,11 @@ counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactua
plts = []
for name ∈ ordered_names
ce = counterfactuals[name]
plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=0)
plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=-0.5)
if :energy_sampler ∈ collect(keys(ce.params))
Xgen = ce.params[:energy_sampler].buffer
scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
end
plts = vcat(plts..., plt)
end
_n = length(generators)
......
......@@ -38,7 +38,7 @@ function EnergySampler(
y::Any;
opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(),
niter::Int=100,
nsamples::Int=1000
nsamples::Int=100
)
@assert y data.y_levels || y 1:length(data.y_levels)
......@@ -64,6 +64,7 @@ Generates `n` samples from `EnergySampler` for conditioning value `y`.
"""
function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
X = e.sampler(e.model, e.opt, (size(e.data.X, 1), n); niter=niter, y=y)
X = X[:,map(x -> !any(isnan.(x)), eachcol(X))]
return X
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment