Skip to content
Snippets Groups Projects
Commit 2a482afc authored by pat-alt's avatar pat-alt
Browse files

uh

parent 4f0c8894
No related branches found
No related tags found
No related merge requests found
......@@ -107,7 +107,7 @@ The right panel of @fig-losses shows the configurable classification loss in the
temp = 0.5
p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true, zoom=0, temp=temp)
p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=target, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2,2))
p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=1, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2,2))
plot(p1, p2, size=(800,320))
```
......@@ -186,11 +186,14 @@ $$ {#eq-solution}
#| label: fig-ce
#| fig-cap: "Comparison of counterfactuals produced using different generators."
λ = 10.0
opt = Descent(0.01)
ordered_names = [
"Generic (γ=0.5)",
"Conformal (λ₂=1)",
"Conformal (λ₂=10)"
"Distance x",
"Set Size",
"Distance X_θ|t",
"Distance X|t"
]
loss_fun = Objectives.logitbinarycrossentropy
generator = GenericGenerator(opt = opt)
......@@ -198,8 +201,9 @@ generator = GenericGenerator(opt = opt)
# Generators:
generators = Dict(
ordered_names[1] => generator,
ordered_names[2] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance_l2 + 1.0set_size_penalty),
ordered_names[3] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance_l2 + 10.0set_size_penalty),
ordered_names[2] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance + 10.0set_size_penalty),
ordered_names[3] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance + 10.0distance_from_energy),
ordered_names[4] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance + 10.0distance_from_targets),
)
counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity, converge_when=:generator_conditions, gradient_tol=1e-3) for (name, gen) in generators])
......@@ -212,8 +216,7 @@ for name ∈ ordered_names
plts = vcat(plts..., plt)
end
_n = length(generators)
img_size = 300
plot(plts..., size=(_n * img_size,1.05*img_size), layout=(1,_n))
plot(plts..., size=(_n * img_height,1.05*img_height), layout=(1,_n))
```
## Multi-Class
......
......@@ -82,6 +82,7 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
yhat = map(eachslice(X, dims=ndims(X))) do x
predict_logits(fitresult, x)
end
yhat = MLUtils.stack(yhat)
else
yhat = predict_logits(fitresult, X)
end
......
......@@ -36,7 +36,7 @@ end
function distance_from_energy(
counterfactual_explanation::AbstractCounterfactualExplanation;
n::Int=100, retrain=false, agg=mean, kwargs...
n::Int=100, from_buffer=true, agg=mean, kwargs...
)
conditional_samples = []
ignore_derivatives() do
......@@ -45,7 +45,7 @@ function distance_from_energy(
_dict[:energy_sampler] = CCE.EnergySampler(counterfactual_explanation; kwargs...)
end
sampler = _dict[:energy_sampler]
push!(conditional_samples, rand(sampler, n; retrain=retrain))
push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer))
end
x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
loss = map(eachslice(x′, dims=3)) do x
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment