diff --git a/Manifest.toml b/Manifest.toml index fea68c34e08a641a939e842391d36fc86d36d563..ec1a4d53558645c9348900c816cedd9ca74b12c3 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "3488ba7b86c20f2db302dbcbc95e3f6227acdebb" +project_hash = "b1b5b9a95ae3c44dd7fa2c6d1d451e9bca8b9297" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] @@ -260,7 +260,7 @@ version = "0.6.2" deps = ["CSV", "CUDA", "CategoricalArrays", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "PkgTemplates", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "Statistics", "StatsBase", "Tables", "UMAP"] path = "../CounterfactualExplanations.jl" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.5" +version = "0.1.6" [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" diff --git a/Project.toml b/Project.toml index 2ff890fd65823b1232424beca23a8c54351bf12a..83402ab8101c65057de4caaffc3e3d0187484bb4 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.1.0" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd index 5dcef400219e44811e5253e71aace478997e38ef..1cea70166e5e905840bb03f7b80f190065900633 100644 --- a/notebooks/conformal.qmd +++ b/notebooks/conformal.qmd @@ -35,7 +35,6 @@ fit!(mach) ```{julia} M = CCE.ConformalModel(conf_model, mach.fitresult) -generator = CCE.ConformalGenerator() ``` ```{julia} @@ -51,5 +50,24 @@ plot(p1, p2, size=(800,320)) ``` ```{julia} -ce = generate_counterfactual(x, target, counterfactual_data, M, generator) -``` +#| output: true +#| echo: false +# Generators: +generators = Dict( + "Generic (γ=0.5)" => GenericGenerator(opt = opt, decision_threshold=0.5), + "Generic (γ=0.9)" => GenericGenerator(opt = opt, decision_threshold=0.9), + "Conformal (λ₂=1)" => CCE.ConformalGenerator(opt=opt, λ=[0.1,1]), + "Conformal (λ₂=10)" => CCE.ConformalGenerator(opt=opt, λ=[0.1,10]), +) + + +counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen;) for (name, gen) in generators]) +# Plots: +plts = [] +for (name,ce) ∈ counterfactuals + plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=0) + plts = vcat(plts..., plt) +end +_n = length(generators) +plot(plts..., size=(_n * 185,200), layout=(1,_n)) +``` \ No newline at end of file diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl index b86bccc6a85f4f6e217e8b9c5f25c89ca90d93f1..133a3362779d6df36ad42010199339a61190e895 100644 --- a/src/ConformalGenerator.jl +++ b/src/ConformalGenerator.jl @@ -13,6 +13,8 @@ mutable struct ConformalGenerator <: AbstractGradientBasedGenerator decision_threshold::Union{Nothing,AbstractFloat} opt::Flux.Optimise.AbstractOptimiser # optimizer τ::AbstractFloat # tolerance for convergence + κ::Real + temp::Real end # API streamlining: @@ -48,7 +50,7 @@ function ConformalGenerator(; kwargs..., ) params = ConformalGeneratorParams(; kwargs...) - ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ) + ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ, params.κ, params.temp) end """ @@ -68,6 +70,7 @@ function set_size_penalty( fitresult = counterfactual_explanation.M.fitresult X = CounterfactualExplanations.decode_state(counterfactual_explanation) loss = SliceMap.slicemap(X, dims=(1,2)) do x + x = Matrix(x) ConformalPrediction.smooth_size_loss( conf_model, fitresult, x; κ = generator.κ, @@ -97,13 +100,19 @@ function Generators.h( CounterfactualExplanations.decode_state(counterfactual_explanation), ) - # Euclidean norm of gradient: - Ω = set_size_penalty(generator, counterfactual_explanation) - + # Set size penalty: + in_target_domain = all(target_probs(counterfactual_explanation) .>= 0.5) + if in_target_domain + Ω = set_size_penalty(generator, counterfactual_explanation) + else + Ω = 0 + end + if length(generator.λ) == 1 penalty = generator.λ * (dist_ .+ Ω) else penalty = generator.λ[1] * dist_ .+ generator.λ[2] * Ω end + return penalty end diff --git a/src/model.jl b/src/model.jl index 654af645614f6fd74519dd63812648f7855ba684..eddd8b3064b42f56a2dabd95a21e370ead1b7ac3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -67,9 +67,10 @@ function Models.logits(M::ConformalModel, X::AbstractArray) # return probas # end p̂ = fitresult[1](x) - if size(p̂, 2) > 1 - p̂ = reduce(hcat, p̂) + if ndims(p̂) == 2 + p̂ = [p̂] end + p̂ = reduce(hcat, p̂) ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂)))) if M.likelihood == :classification_binary ŷ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ))))