Skip to content
Snippets Groups Projects
Commit adf59a74 authored by pat-alt's avatar pat-alt
Browse files

slowly slowly

parent 55aebf0a
No related branches found
No related tags found
No related merge requests found
......@@ -31,7 +31,20 @@ mach = machine(conf_model, X, y)
fit!(mach)
```
```{julia}
contourf(mach.model, mach.fitresult, X, y; plot_set_size=true)
```
## Counterfactual Explanation
```{julia}
M = CCE.ConformalModel(conf_model, mach.fitresult)
generator = CCE.ConformalGenerator()
```
```{julia}
x = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))
y = predict_label(M, counterfactual_data, x)[1]
target = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y][1]
```
\ No newline at end of file
......@@ -3,6 +3,7 @@ using CounterfactualExplanations.Generators
using Flux
using LinearAlgebra
using Parameters
using SliceMap
using Statistics
mutable struct ConformalGenerator <: AbstractGradientBasedGenerator
......@@ -19,7 +20,7 @@ end
opt::Flux.Optimise.AbstractOptimiser = Descent()
τ::AbstractFloat = 1e-3
κ::Real = 1.0
Τ::Real = 0.5
temp::Real = 0.5
end
"""
......@@ -50,31 +51,6 @@ function ConformalGenerator(;
ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ)
end
# Loss:
# """
# ℓ(generator::ConformalGenerator, counterfactual_explanation::AbstractCounterfactualExplanation)
# The default method to apply the generator loss function to the current counterfactual state for any generator.
# """
# function ℓ(
# generator::ConformalGenerator,
# counterfactual_explanation::AbstractCounterfactualExplanation,
# )
# loss_fun =
# !isnothing(generator.loss) ? getfield(Losses, generator.loss) :
# CounterfactualExplanations.guess_loss(counterfactual_explanation)
# @assert !isnothing(loss_fun) "No loss function provided and loss function could not be guessed based on model."
# loss = loss_fun(
# getfield(Models, :logits)(
# counterfactual_explanation.M,
# CounterfactualExplanations.decode_state(counterfactual_explanation),
# ),
# counterfactual_explanation.target_encoded,
# )
# return loss
# end
"""
set_size_penalty(
generator::ConformalGenerator,
......@@ -88,6 +64,19 @@ function set_size_penalty(
counterfactual_explanation::AbstractCounterfactualExplanation,
)
conf_model = counterfactual_explanation.M.model
fitresult = counterfactual_explanation.M.fitresult
X = CounterfactualExplanations.decode_state(counterfactual_explanation)
loss = SliceMap.slicemap(X, dims=(1,2)) do x
ConformalPrediction.smooth_size_loss(
conf_model, fitresult, x;
κ = generator.κ,
temp = generator.temp
)
end
loss = mean(loss)
return loss
end
......@@ -109,17 +98,12 @@ function Generators.h(
)
# Euclidean norm of gradient:
in_target_domain = all(target_probs(counterfactual_explanation) .>= 0.5)
if in_target_domain
grad_norm = gradient_penalty(generator, counterfactual_explanation)
else
grad_norm = 0
end
Ω = set_size_penalty(generator, counterfactual_explanation)
if length(generator.λ) == 1
penalty = generator.λ * (dist_ .+ grad_norm)
penalty = generator.λ * (dist_ .+ Ω)
else
penalty = generator.λ[1] * dist_ .+ generator.λ[2] * grad_norm
penalty = generator.λ[1] * dist_ .+ generator.λ[2] * Ω
end
return penalty
end
......@@ -59,8 +59,8 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
yhat = SliceMap.slicemap(X, dims=(1, 2)) do x
conf_model = M.model
fitresult = M.fitresult
X = MLJBase.table(permutedims(X))
= MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, X)...)
x = MLJBase.table(permutedims(x))
= MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, x)...)
= map() do pp
L = .decoder.classes
probas = pdf.(pp, L)
......@@ -69,8 +69,9 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
= reduce(hcat, )
= reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol())))
if M.likelihood == :classification_binary
p̂ = reduce(hcat, (map(y -> y[2] - y[1], eachcol())))
ŷ = reduce(hcat, (map(y -> y[2] - y[1], eachcol())))
end
return
end
return yhat
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment