diff --git a/notebooks/fidelity.qmd b/notebooks/fidelity.qmd index bb499e48e01baf9bbc76c69e494b95498d2911d6..e121e4df4147b0c3cab9c839856140c27ac4f444 100644 --- a/notebooks/fidelity.qmd +++ b/notebooks/fidelity.qmd @@ -107,7 +107,7 @@ The right panel of @fig-losses shows the configurable classification loss in the temp = 0.5 p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true, zoom=0, temp=temp) -p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=target, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2,2)) +p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=1, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2,2)) plot(p1, p2, size=(800,320)) ``` @@ -186,11 +186,14 @@ $$ {#eq-solution} #| label: fig-ce #| fig-cap: "Comparison of counterfactuals produced using different generators." +λ = 10.0 + opt = Descent(0.01) ordered_names = [ - "Generic (γ=0.5)", - "Conformal (λ₂=1)", - "Conformal (λ₂=10)" + "Distance x", + "Set Size", + "Distance X_θ|t", + "Distance X|t" ] loss_fun = Objectives.logitbinarycrossentropy generator = GenericGenerator(opt = opt) @@ -198,8 +201,9 @@ generator = GenericGenerator(opt = opt) # Generators: generators = Dict( ordered_names[1] => generator, - ordered_names[2] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance_l2 + 1.0set_size_penalty), - ordered_names[3] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance_l2 + 10.0set_size_penalty), + ordered_names[2] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance + 10.0set_size_penalty), + ordered_names[3] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance + 10.0distance_from_energy), + ordered_names[4] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance + 10.0distance_from_targets), ) counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity, converge_when=:generator_conditions, gradient_tol=1e-3) for (name, gen) in generators]) @@ -212,8 +216,7 @@ for name ∈ ordered_names plts = vcat(plts..., plt) end _n = length(generators) -img_size = 300 -plot(plts..., size=(_n * img_size,1.05*img_size), layout=(1,_n)) +plot(plts..., size=(_n * img_height,1.05*img_height), layout=(1,_n)) ``` ## Multi-Class diff --git a/src/model.jl b/src/model.jl index f12c3ae901aacdf5af14df686ede2f4aa540d2f1..80fb20eb8b50961341f5a3b1e706920fdd7e126c 100644 --- a/src/model.jl +++ b/src/model.jl @@ -82,6 +82,7 @@ function Models.logits(M::ConformalModel, X::AbstractArray) yhat = map(eachslice(X, dims=ndims(X))) do x predict_logits(fitresult, x) end + yhat = MLUtils.stack(yhat) else yhat = predict_logits(fitresult, X) end diff --git a/src/penalties.jl b/src/penalties.jl index 5f85af406725ad258610b65a54ce8ecd371f13c2..3f5751cb6dfb2c2a7f294df2332ac5edc5f6280f 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -36,7 +36,7 @@ end function distance_from_energy( counterfactual_explanation::AbstractCounterfactualExplanation; - n::Int=100, retrain=false, agg=mean, kwargs... + n::Int=100, from_buffer=true, agg=mean, kwargs... ) conditional_samples = [] ignore_derivatives() do @@ -45,7 +45,7 @@ function distance_from_energy( _dict[:energy_sampler] = CCE.EnergySampler(counterfactual_explanation; kwargs...) end sampler = _dict[:energy_sampler] - push!(conditional_samples, rand(sampler, n; retrain=retrain)) + push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer)) end x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation) loss = map(eachslice(x′, dims=3)) do x