```{julia}
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
```
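
First, we generate synthetic data for the proof of concept: two low-variance Gaussian blobs in two dimensions, with cluster centers drawn from the box $[-1,1]^2$.
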
```{julia}
# Counterfactual data:
n_obs = 1000
counterfactual_data = load_blobs(n_obs; cluster_std=0.1, center_box=(-1. => 1.))
X = counterfactual_data.X
y = counterfactual_data.y
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(X)
batch_size = Int(round(n_obs / 10))
epochs = 100
```
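
A quick scatter plot of the two classes:
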
```{julia}
Plots.plot()
display(Plots.scatter!(counterfactual_data))
```
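
Next, we set up the joint energy model (JEM): a small multi-layer perceptron that is trained jointly on the standard cross-entropy objective and a generative, energy-based objective. The conditional sampler is used to draw class-conditional samples during training.
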
```{julia}
# Priors over inputs and (two) classes for the conditional sampler:
𝒟x = Normal()
𝒟y = Categorical(ones(2) ./ 2)
sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=size(X)[1:end-1], batch_size=50)

# Joint energy classifier: an MLP trained on both the discriminative
# and the generative objective.
n_hidden = 16
clf = JointEnergyClassifier(
    sampler;
    builder=MLJFlux.MLP(
        hidden=(n_hidden, n_hidden, n_hidden),
        σ=Flux.swish
    ),
    batch_size=batch_size,
    finaliser=Flux.softmax,
    loss=Flux.Losses.crossentropy,
    jem_training_params=(
        α=[1.0, 1.0, 1e-1],    # weights on the training objectives
        verbosity=10,
    ),
    epochs=epochs,
    sampling_steps=30,         # number of sampling steps used during training
)
```
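
The classifier is then calibrated through simple inductive conformal prediction with a target coverage of 95%. The fitted machine is serialized for later reuse:
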
```{julia}
# Wrap the classifier in a conformal model (simple inductive CP, 95% coverage):
method = :simple_inductive
cov = .95
conf_model = conformal_model(clf; method=method, coverage=cov)
mach = machine(conf_model, table(permutedims(X)), labels)
fit!(mach)
Serialization.serialize(joinpath(output_path, "poc_model.jls"), mach)
```
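
As a quick sanity check, one could inspect the empirical coverage of the fitted conformal predictor. The minimal sketch below is not part of the original pipeline: it assumes `emp_coverage` as provided by ConformalPrediction.jl and evaluates on the training data, so it should be read as a smoke test rather than a proper coverage estimate.

```{julia}
# Minimal sanity check (sketch): set-valued predictions should cover the
# true label at roughly the nominal 95% rate. Assumes `emp_coverage` from
# ConformalPrediction.jl is in scope via the notebook setup.
ŷ = predict(mach, table(permutedims(X)))
println("Empirical coverage: $(round(emp_coverage(ŷ, labels), digits=3))")
```

With the model fitted, the hidden chunk below visualizes what the JEM has learned: for each target class we draw conditional samples from the model and overlay them (stars) on the predicted probability contours and the observed data.
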
```{julia}
#| echo: false
niter = 1000
jem = mach.model.model.jem
batch_size = mach.model.model.batch_size
X = Float32.(X)    # ensure Float32 inputs for the Flux-based model
if typeof(jem.sampler) <: ConditionalSampler
    plts = []
    for target in 1:2
        # Draw class-conditional samples from the JEM:
        X̂ = generate_conditional_samples(jem, batch_size, target; niter=niter)
        ex = extrema(hcat(X, X̂), dims=2)
        xlims = ex[1]
        ylims = ex[2]
        x1 = range(1.0f0 .* xlims..., length=100)
        x2 = range(1.0f0 .* ylims..., length=100)
        # Predicted probability for the target class as a filled contour:
        plt = Plots.contour(
            x1, x2, (x, y) -> softmax(jem([x, y]))[target],
            fill=true, alpha=0.5, title="Target: $target", cbar=true,
            xlims=xlims,
            ylims=ylims,
        )
        # Observed data and generated samples (stars):
        Plots.scatter!(X[1, :], X[2, :], color=Int.(labels.refs), group=Int.(labels.refs), alpha=0.5)
        Plots.scatter!(
            X̂[1, :], X̂[2, :],
            color=repeat([target], size(X̂, 2)),
            group=repeat([target], size(X̂, 2)),
            shape=:star5, ms=10
        )
        push!(plts, plt)
    end
    plt = Plots.plot(plts..., layout=(1, 2), size=(2 * 500, 400))
    display(plt)
end
```
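
@fig-losses illustrates the two penalty surfaces that ECCCo uses on top of the distance penalty: the smooth set size loss derived from the conformal model and the configurable classification loss, both evaluated at temperature `temp`.
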
```{julia}
#| output: true
#| echo: false
#| label: fig-losses
#| fig-cap: "Illustration of the smooth size loss and the configurable classification loss."
temp = 0.25    # temperature used in the smooth penalty functions
p1 = Plots.contourf(mach.model, mach.fitresult, permutedims(X), labels; plot_set_loss=true, zoom=0, temp=temp)
p2 = Plots.contourf(mach.model, mach.fitresult, permutedims(X), labels; plot_classification_loss=true, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2, 2))
display(Plots.plot(p1, p2, size=(800, 320)))
```
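
Now we generate counterfactuals for a randomly chosen factual instance, comparing ECCCo against three baseline generators: Wachter, Schut (greedy) and REVISE. For ECCCo, we also overlay the conditional samples that feed into its energy-based penalty:
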
```{julia}
Random.seed!(1234)

# Penalty strengths (distance, set size, energy):
λ₁ = 0.1
λ₂ = 0.4
λ₃ = 0.4
Λ = [λ₁, λ₂, λ₃]

# Wrap the fitted conformal machine as a model for counterfactual search:
M = ECCCo.ConformalModel(mach.model, mach.fitresult)

# Pick a random factual from the second class; target the first class:
factual_label = levels(labels)[2]
x_factual = reshape(X[:, rand(findall(predict_label(M, counterfactual_data) .== factual_label))], input_dim, 1)
target = levels(labels)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]

# Benchmark generators:
opt = Flux.Optimise.Descent(0.01)
generator_dict = OrderedDict(
    "Wachter" => WachterGenerator(λ=λ₁, opt=opt),
    "Schut" => GreedyGenerator(λ=λ₁),
    "REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
    "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt),
)

ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )
    plt = Plots.plot(
        ce, title=name, alpha=0.2, cbar=false,
        axis=nothing, length_out=10, contour_alpha=1.0,
    )
    # For ECCCo, overlay the conditional samples used by the energy penalty:
    if name == "ECCCo"
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1, :], _X[2, :], color=:purple, shape=:star5,
            ms=10, label="x̂|$target", alpha=0.1
        )
    end
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(500, 520))
display(plt)
savefig(plt, joinpath(output_images_path, "poc.png"))
```
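
To understand where the different generators pull the counterfactual, we repeat the exercise and overlay the negative gradient field of each generator's objective. We also include two ECCCo ablations: one without the energy-based penalty (no EBM) and one without the conformal set size penalty (no CP):
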
```{julia}
using Colors
col_pal = palette(:seaborn_colorblind)
Random.seed!(1234)

using CounterfactualExplanations.Generators: ∇

# Penalty strengths (distance, set size, energy) and step size:
λ₁ = 0.1
λ₂ = 1.0
λ₃ = 0.5
Λ = [λ₁, λ₂, λ₃]
η = 0.01

M = ECCCo.ConformalModel(mach.model, mach.fitresult)
factual_label = levels(labels)[2]
x_factual = reshape(X[:, rand(findall(predict_label(M, counterfactual_data) .== factual_label))], input_dim, 1)
target = levels(labels)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]
opt = Flux.Optimise.Descent(η)

# Full ECCCo plus two ablations (no EBM penalty, no CP penalty):
generator_dict = OrderedDict(
    "Wachter" => WachterGenerator(λ=0.3, opt=opt),
    "ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0], opt=opt),
    "ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃], opt=opt),
    "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt),
)

# Gradient of the counterfactual search objective at state x:
function loss_grads(generator, model, ce, x)
    x = Float32.(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    return ∇(generator, model, _ce)
end

# Grid on which to evaluate the gradient field:
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))
xlims, ylims = extrema(X, dims=2)
xrange = range(xlims..., length=10)
yrange = range(ylims..., length=10)
x1, x2 = meshgrid(xrange, yrange)
inputs = zip(x1, x2)

# Plot:
ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
    # CE:
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )
    # Main plot (path):
    plt = Plots.plot(
        ce, title=name, alpha=0.1, cbar=false,
        axis=nothing, length_out=10, contour_alpha=0.0,
        legend=false,
        palette=col_pal,
    )
    # Generated samples:
    if name ∈ ["ECCCo", "ECCCo (no CP)"]
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1, :], _X[2, :], color=col_pal[end-1], shape=:star5,
            ms=10, label="x̂|$target", alpha=0.1
        )
    end
    # Gradient field (arrows point in the descent direction):
    u = []
    v = []
    for (x, y) in inputs
        g = -loss_grads(generator, M, ce, [x, y][:, :])
        push!(u, η * g[1])
        push!(v, η * g[2])
    end
    Plots.quiver!(x1, x2, quiver=(u, v), color=col_pal[5])
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(500, 520))
display(plt)
savefig(plt, joinpath(output_images_path, "poc_gradient_fields.png"))
```
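
Finally, we decompose the full ECCCo objective for the fitted counterfactual into its individual components and plot the gradient field of the classification loss:
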
```{julia}
ce = ces["ECCCo"]
using CounterfactualExplanations.Generators: ℓ

# Each fᵢ evaluates (the negative of) one component of the ECCCo objective
# at counterfactual state x.

# Classification loss:
function f1(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    return -ℓ(_ce.generator, _ce)
end

# Distance (cost) penalty:
function f2(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    λ = _ce.generator.λ[1]
    _loss = _ce.generator.penalty[1]
    return -λ * _loss(_ce)
end

# Set size penalty:
function f3(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    λ = _ce.generator.λ[2]
    _loss = _ce.generator.penalty[2]
    return -λ * _loss(_ce)
end

# Distance from energy:
function f4(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    λ = _ce.generator.λ[3]
    _loss = _ce.generator.penalty[3]
    return -λ * _loss(_ce)
end

# Helper function:
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))
xlims, ylims = extrema(X, dims=2)
xrange = range(xlims..., length=10)
yrange = range(ylims..., length=10)
x1, x2 = meshgrid(xrange, yrange)
inputs = zip(x1, x2)

# Arrow scaling (reuses the step size η defined above):
scale = η
u = []
v = []
for (x, y) in inputs
    push!(u, scale * gradient(f1, [x, y][:, :])[1][1])
    push!(v, scale * gradient(f1, [x, y][:, :])[1][2])
end
Plots.plot(xlims=xlims, ylims=ylims)
Plots.scatter!(counterfactual_data)
Plots.quiver!(x1, x2, quiver=(u, v))
```
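
Swapping `f1` for `f2`, `f3` or `f4` in the loop above produces the corresponding gradient field for the distance, set size and energy penalties, respectively.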