Skip to content
Snippets Groups Projects
Commit b57c969f authored by pat-alt's avatar pat-alt
Browse files

penalty working

parent d5316c91
No related branches found
No related tags found
No related merge requests found
......@@ -20,25 +20,32 @@ using Plots
counterfactual_data = load_multi_class()
M = fit_model(counterfactual_data, :MLP)
target = 4
factual = 1
factual = 2
chosen = rand(findall(predict_label(M, counterfactual_data) .== factual))
x = select_factual(counterfactual_data, chosen)
```
# Search:
generator = GenericGenerator(opt=Descent(0.01))
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
```
```{julia}
niter = 10000
niter = 100
nsamples = 100
plts = []
for target in ce.data.y_levels
# Search:
generator = GenericGenerator()
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100)
Xgen = rand(sampler, nsamples)
plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false)
scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
push!(plts, plt)
end
plot(plts..., layout=(1,length(ce.data.y_levels)), size=(length(ce.data.y_levels)*300,300))
sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100)
Xgen = rand(sampler, nsamples)
plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false)
scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
```
```{julia}
p1 = plot(ce)
```
```{julia}
@objective(generator, _ + 0.1distance_l2 + 100.0distance_from_energy)
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
p2 = plot(ce)
```
\ No newline at end of file
using ChainRules: ignore_derivatives
using LinearAlgebra: norm
using Statistics: mean
"""
......@@ -33,11 +35,28 @@ function set_size_penalty(
end
function distance_from_energy(
counterfactual_explanation::AbstractCounterfactualExplanation;
n::Int=100, retrain=false, kwargs...
counterfactual_explanation::AbstractCounterfactualExplanation;
n::Int=100, retrain=false, agg=mean, kwargs...
)
sampler = get!(counterfactual_explanation.params, :energy_sampler) do
CCE.EnergySampler(counterfactual_explanation; kwargs...)
conditional_samples = []
ignore_derivatives() do
_dict = counterfactual_explanation.params
if !(:energy_sampler collect(keys(_dict)))
_dict[:energy_sampler] = CCE.EnergySampler(counterfactual_explanation; kwargs...)
end
sampler = _dict[:energy_sampler]
push!(conditional_samples, rand(sampler, n; retrain=retrain))
end
x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
loss = map(eachslice(x′, dims=3)) do x
x = Matrix(x)
Δ = map(eachcol(conditional_samples[1])) do xsample
norm(x - xsample)
end
return mean(Δ)
end
conditional_samples =
loss = agg(loss)
return loss
end
\ No newline at end of file
......@@ -13,7 +13,7 @@ end
function EnergySampler(
ce::CounterfactualExplanation;
opt::JointEnergyModels.AbstractSamplingRule=SGLD(),
opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(),
niter::Int=100,
nsamples::Int=1000
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment