Skip to content
Snippets Groups Projects
Commit 0c057ef9 authored by pat-alt's avatar pat-alt
Browse files

more on sampling

parent 3c5f449d
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@
julia_version = "1.8.5"
manifest_format = "2.0"
project_hash = "c05645661b945bea85f24ae44163718eb7f430c7"
project_hash = "d4865488d9d741cc2d833e66a476d21fceebe412"
[[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"]
......@@ -275,7 +275,7 @@ uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3"
version = "0.3.2"
[[deps.ConformalPrediction]]
deps = ["CategoricalArrays", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "StatsBase"]
deps = ["CategoricalArrays", "ChainRules", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "StatsBase"]
path = "../ConformalPrediction.jl"
uuid = "98bfc277-1877-43dc-819b-a3e38c30242f"
version = "0.1.6"
......@@ -767,6 +767,12 @@ git-tree-sha1 = "84b10656a41ef564c39d2d477d7236966d2b5683"
uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
version = "1.12.0"
[[deps.JointEnergyModels]]
deps = ["Distributions", "Flux", "StatsBase"]
path = "../JointEnergyModels.jl"
uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
version = "0.1.0"
[[deps.JpegTurbo_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "6f2675ef130a300a112286de91973805fcc5ffbc"
......
......@@ -9,7 +9,9 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f"
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
......
......@@ -120,21 +120,20 @@ plot(p1, p2, size=(800,320))
opt = Descent(0.01)
ordered_names = [
"Generic (γ=0.5)",
"Generic (γ=0.9)",
"Conformal (λ₂=1)",
"Conformal (λ₂=10)"
]
loss_fun = Objectives.logitbinarycrossentropy
generator = GenericGenerator(opt = opt)
# Generators:
generators = Dict(
ordered_names[1] => GenericGenerator(opt = opt, decision_threshold=0.5),
ordered_names[2] => GenericGenerator(opt = opt, decision_threshold=0.9),
ordered_names[3] => CCE.ConformalGenerator(loss=loss_fun, opt=opt, λ=[0.1,1]),
ordered_names[4] => CCE.ConformalGenerator(loss=loss_fun, opt=opt, λ=[0.1,10]),
ordered_names[1] => generator,
ordered_names[2] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance_l2 + 1.0set_size_penalty),
ordered_names[3] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance_l2 + 10.0set_size_penalty),
)
counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity) for (name, gen) in generators])
counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity, converge_when=:generator_conditions, gradient_tol=1e-3) for (name, gen) in generators])
# Plots:
plts = []
......
```{julia}
using CCE
using ConformalPrediction
using CounterfactualExplanations
using CounterfactualExplanations.Data
using CounterfactualExplanations.Objectives
using Distributions
using Flux
using JointEnergyModels
using LinearAlgebra
using MLJBase
using MLJFlux
using Plots
```
# Fidelity Measures
```{julia}
# Setup
counterfactual_data = load_multi_class()
M = fit_model(counterfactual_data, :MLP)
target = 4
factual = 1
chosen = rand(findall(predict_label(M, counterfactual_data) .== factual))
x = select_factual(counterfactual_data, chosen)
```
```{julia}
niter = 10
nsamples = 100
plts = []
for target in ce.data.y_levels
# Search:
generator = GenericGenerator()
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100)
Xgen = rand(sampler, nsamples)
plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false)
scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
push!(plts, plt)
end
plot(plts..., layout=(1,length(ce.data.y_levels)), size=(length(ce.data.y_levels)*300,300))
```
\ No newline at end of file
module CCE
using CounterfactualExplanations
import MLJModelInterface as MMI
include("model.jl")
include("ConformalGenerator.jl")
include("penalties.jl")
include("losses.jl")
include("generator.jl")
include("sampling.jl")
# include("ConformalGenerator.jl")
end
using Statistics: mean
"""
set_size_penalty(counterfactual_explanation::AbstractCounterfactualExplanation)
Penalty for smooth conformal set size.
"""
function set_size_penalty(
counterfactual_explanation::AbstractCounterfactualExplanation;
κ::Real=1.0, temp::Real=0.5, agg=mean
)
conf_model = counterfactual_explanation.M.model
fitresult = counterfactual_explanation.M.fitresult
X = CounterfactualExplanations.decode_state(counterfactual_explanation)
loss = map(eachslice(X, dims=3)) do x
x = Matrix(x)
if target_probs(counterfactual_explanation, x)[1] >= 0.5
l = ConformalPrediction.smooth_size_loss(
conf_model, fitresult, x;
κ=κ,
temp=temp
)[1]
else
l = 0.0
end
return l
end
loss = agg(loss)
return loss
end
function distance_from_energy(
counterfactual_explanation::AbstractCounterfactualExplanation;
n::Int=100, retrain=false, kwargs...
)
sampler = get!(counterfactual_explanation.params, :energy_sampler) do
CCE.EnergySampler(counterfactual_explanation; kwargs...)
end
conditional_samples =
end
\ No newline at end of file
using CounterfactualExplanations
using Distributions
using JointEnergyModels
(model::AbstractFittedModel)(x) = logits(model, x)
mutable struct EnergySampler
ce::CounterfactualExplanation
sampler::JointEnergyModels.ConditionalSampler
opt::JointEnergyModels.AbstractSamplingRule
buffer::AbstractArray
end
function EnergySampler(
ce::CounterfactualExplanation;
opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(),
niter::Int=100,
nsamples::Int=1000
)
# Setup:
model = ce.M
data = ce.data
K = length(data.y_levels)
𝒟x = Normal()
𝒟y = Categorical(ones(K) ./ K)
sampler = ConditionalSampler(𝒟x, 𝒟y)
# Fit:
i = get_target_index(data.y_levels, ce.target)
buffer = sampler(model.model, opt, (size(data.X, 1), nsamples); niter=niter, y=i)
return EnergySampler(ce, sampler, opt, buffer)
end
function Base.rand(sampler::EnergySampler, n::Int=100; retrain=false)
ntotal = size(sampler.buffer,2)
idx = rand(1:ntotal, n)
if !retrain
X = sampler.buffer[:,idx]
end
return X
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment