Skip to content
Snippets Groups Projects
Commit ae9a78ac authored by pat-alt's avatar pat-alt
Browse files

plot reviewer 1

parent 21e694ac
No related branches found
No related tags found
1 merge request!8985 overshooting
......@@ -12,7 +12,7 @@ include("$(pwd())/experiments/setup_env.jl")
## Counterfactual Path - MNIST
```{julia}
outcome = Serialization.deserialize("results/linearly_separable_outcome.jls")
outcome = Serialization.deserialize("results/moons_outcome.jls")
data = outcome.exper.counterfactual_data
```
......@@ -23,16 +23,22 @@ target = data.y_levels[1]
```{julia}
n_samp = 100
n_rand = 500
alpha = 0.2
plts = []
for (mod_name, model) in outcome.model_dict
t = get_target_index(data.y_levels, target)
E(x) = -logits(model, x)[t, :]
x_samp = data.X[:,rand(findall(data.output_encoder.labels.==target),n_samp)]
x_rand = Float32.(randn(size(data.X,1), 1000))
x_rand = Float32.(randn(size(data.X,1), n_rand))
dist = map(eachcol(x_rand)) do x
mean(map(y -> norm(x.-y),eachcol(x_samp)))
end
plt = scatter(E(x_rand), dist, alpha=0.5, label="", title=mod_name)
plt = scatter(
E(x_rand), dist;
fillalpha=alpha, label="", title=mod_name, smooth=:true,
lc=:red, lw=2
)
push!(plts, plt)
end
plot(plts..., layout=(1,length(outcome.model_dict)), size=(1000,250))
......
Source diff could not be displayed: it is too large. Options to address this: view the blob.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment