Skip to content
Snippets Groups Projects
Commit b7a78dfa authored by pat-alt's avatar pat-alt
Browse files

basic set size loss penalty now for working counterfactual search

parent 147234e7
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@
julia_version = "1.8.5"
manifest_format = "2.0"
project_hash = "3488ba7b86c20f2db302dbcbc95e3f6227acdebb"
project_hash = "b1b5b9a95ae3c44dd7fa2c6d1d451e9bca8b9297"
[[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"]
......@@ -260,7 +260,7 @@ version = "0.6.2"
deps = ["CSV", "CUDA", "CategoricalArrays", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "PkgTemplates", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "Statistics", "StatsBase", "Tables", "UMAP"]
path = "../CounterfactualExplanations.jl"
uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
version = "0.1.5"
version = "0.1.6"
[[deps.Crayons]]
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
......
......@@ -5,6 +5,7 @@ version = "0.1.0"
[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f"
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
......
......@@ -35,7 +35,6 @@ fit!(mach)
```{julia}
M = CCE.ConformalModel(conf_model, mach.fitresult)
generator = CCE.ConformalGenerator()
```
```{julia}
......@@ -51,5 +50,24 @@ plot(p1, p2, size=(800,320))
```
```{julia}
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
```
#| output: true
#| echo: false
# Generators:
generators = Dict(
"Generic (γ=0.5)" => GenericGenerator(opt = opt, decision_threshold=0.5),
"Generic (γ=0.9)" => GenericGenerator(opt = opt, decision_threshold=0.9),
"Conformal (λ₂=1)" => CCE.ConformalGenerator(opt=opt, λ=[0.1,1]),
"Conformal (λ₂=10)" => CCE.ConformalGenerator(opt=opt, λ=[0.1,10]),
)
counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen;) for (name, gen) in generators])
# Plots:
plts = []
for (name,ce) ∈ counterfactuals
plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=0)
plts = vcat(plts..., plt)
end
_n = length(generators)
plot(plts..., size=(_n * 185,200), layout=(1,_n))
```
\ No newline at end of file
......@@ -13,6 +13,8 @@ mutable struct ConformalGenerator <: AbstractGradientBasedGenerator
decision_threshold::Union{Nothing,AbstractFloat}
opt::Flux.Optimise.AbstractOptimiser # optimizer
τ::AbstractFloat # tolerance for convergence
κ::Real
temp::Real
end
# API streamlining:
......@@ -48,7 +50,7 @@ function ConformalGenerator(;
kwargs...,
)
params = ConformalGeneratorParams(; kwargs...)
ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ)
ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ, params.κ, params.temp)
end
"""
......@@ -68,6 +70,7 @@ function set_size_penalty(
fitresult = counterfactual_explanation.M.fitresult
X = CounterfactualExplanations.decode_state(counterfactual_explanation)
loss = SliceMap.slicemap(X, dims=(1,2)) do x
x = Matrix(x)
ConformalPrediction.smooth_size_loss(
conf_model, fitresult, x;
κ = generator.κ,
......@@ -97,13 +100,19 @@ function Generators.h(
CounterfactualExplanations.decode_state(counterfactual_explanation),
)
# Euclidean norm of gradient:
Ω = set_size_penalty(generator, counterfactual_explanation)
# Set size penalty:
in_target_domain = all(target_probs(counterfactual_explanation) .>= 0.5)
if in_target_domain
Ω = set_size_penalty(generator, counterfactual_explanation)
else
Ω = 0
end
if length(generator.λ) == 1
penalty = generator.λ * (dist_ .+ Ω)
else
penalty = generator.λ[1] * dist_ .+ generator.λ[2] * Ω
end
return penalty
end
......@@ -67,9 +67,10 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
# return probas
# end
= fitresult[1](x)
if size(, 2) > 1
= reduce(hcat, )
if ndims() == 2
= []
end
= reduce(hcat, )
= reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol())))
if M.likelihood == :classification_binary
= reduce(hcat, (map(y -> y[2] - y[1], eachcol())))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment