poc.qmd 4.90 KiB
```{julia}
# Load the shared notebook setup script. The path is built from `pwd()`, so the
# notebook must be run with the project root as working directory.
# NOTE(review): presumably this script brings the packages and paths used below
# (e.g. `output_path`, `output_images_path`) into scope — confirm in notebooks/setup.jl.
include("$(pwd())/notebooks/setup.jl")
# Evaluate `setup_notebooks` (presumably a quoted setup expression defined by the script above).
eval(setup_notebooks)
```
```{julia}
# Counterfactual data:
n_obs = 1000
counterfactual_data = load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0))
X, y = counterfactual_data.X, counterfactual_data.y
labels = counterfactual_data.output_encoder.labels
# Feature dimension and number of observations (X holds one observation per column).
input_dim, nobs = size(X)
batch_size = round(Int, nobs / 10)  # mini-batches of ~10% of the observations
epochs = 100
```
```{julia}
# Scatter the counterfactual data on a fresh canvas and show it.
let canvas = Plots.plot()
    Plots.scatter!(canvas, counterfactual_data)
    display(canvas)
end
```
```{julia}
# Distributions handed to the conditional sampler: a standard normal over inputs
# and a uniform categorical over the classes present in `labels`.
n_classes = length(levels(labels))
𝒟x = Normal()
𝒟y = Categorical(fill(1 / n_classes, n_classes))
sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=size(X)[1:end-1], batch_size=50)

# Joint energy classifier on top of a 3-hidden-layer MLP.
n_hidden = 16
clf = JointEnergyClassifier(
    sampler;
    builder=MLJFlux.MLP(
        hidden=(n_hidden, n_hidden, n_hidden),
        σ=Flux.swish
    ),
    batch_size=batch_size,
    finaliser=identity,  # keep raw outputs; the loss below works on logits
    loss=Flux.Losses.logitcrossentropy,
    jem_training_params=(
        α=[1.0, 1.0, 1e-1],
        verbosity=10,
    ),
    epochs=epochs,
    sampling_steps=30,
)
```
```{julia}
# Wrap the classifier in a conformal model using the `:simple_inductive` method.
method = :simple_inductive
# Target coverage. Deliberately not named `cov`: that name would shadow (or fail to
# overwrite) `Statistics.cov` when the Statistics stdlib is in scope.
coverage_level = 0.95
conf_model = conformal_model(clf; method=method, coverage=coverage_level)
# X is features × observations, so transpose it into an observations-per-row table.
mach = machine(conf_model, table(permutedims(X)), labels)
fit!(mach)
# Serialise the fitted machine. `output_path` is assumed to come from the setup script — confirm.
Serialization.serialize(joinpath(output_path, "poc_model.jls"), mach)
```
```{julia}
#| echo: false
niter = 1000
jem = mach.model.model.jem
batch_size = mach.model.model.batch_size
X = Float32.(matrix(X))

"""
    plot_conditional_samples(jem, X, labels, target; batch_size, niter)

Generate `batch_size` conditional samples for class index `target` from `jem` (passing
`niter` through) and return a plot overlaying them (stars) on a filled contour of
`softmax(jem(x))[target]` and on a scatter of the data `X` coloured by `labels`.
Assumes 2-D inputs, i.e. `X` is 2 × nobs.
"""
function plot_conditional_samples(jem, X, labels, target; batch_size, niter)
    X̂ = generate_conditional_samples(jem, batch_size, target; niter=niter)
    # Axis limits span both the observed data and the generated samples.
    xlims, ylims = extrema(hcat(X, X̂), dims=2)
    x1 = range((1.0f0 .* xlims)..., length=100)
    x2 = range((1.0f0 .* ylims)..., length=100)
    plt = Plots.contour(
        x1, x2, (x, y) -> softmax(jem([x, y]))[target],
        fill=true, alpha=0.5, title="Target: $target", cbar=true,
        xlims=xlims,
        ylims=ylims,
    )
    label_codes = Int.(labels.refs)
    Plots.scatter!(plt, X[1, :], X[2, :], color=label_codes, group=label_codes, alpha=0.5)
    n_samples = size(X̂, 2)
    Plots.scatter!(
        plt, X̂[1, :], X̂[2, :],
        color=fill(target, n_samples),
        group=fill(target, n_samples),
        shape=:star5, ms=10,
    )
    return plt
end

if jem.sampler isa ConditionalSampler
    # One panel per class present in the data.
    n_targets = length(levels(labels))
    plts = [
        plot_conditional_samples(jem, X, labels, target; batch_size=batch_size, niter=niter)
        for target in 1:n_targets
    ]
    plt = Plots.plot(plts..., layout=(1, n_targets), size=(n_targets * 500, 400))
    display(plt)
end
```
```{julia}
Random.seed!(1234)

# Penalty weights: the baseline generators use λ₁ only, ECCCo uses all three (Λ).
λ₁ = 0.1
λ₂ = 0.4
λ₃ = 0.4
Λ = [λ₁, λ₂, λ₃]
M = ECCCo.ConformalModel(mach.model, mach.fitresult)

# Pick a random factual among the observations predicted as the second class;
# the counterfactual target is the first class.
factual_label = levels(labels)[2]
predicted_labels = predict_label(M, counterfactual_data)
candidate_idx = findall(predicted_labels .== factual_label)
isempty(candidate_idx) &&
    error("No observation is predicted as $(factual_label); cannot choose a factual.")
x_factual = reshape(X[:, rand(candidate_idx)], input_dim, 1)
target = levels(labels)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]

opt = Flux.Optimise.Descent(0.01)
generator_dict = OrderedDict(
    "Wachter" => WachterGenerator(λ = λ₁, opt=opt),
    "Schut" => GreedyGenerator(λ = λ₁),
    "REVISE" => REVISEGenerator(λ = λ₁, opt=opt),
    "ECCCo" => ECCCoGenerator(λ = Λ, opt=opt),
)

# Generate one counterfactual per generator and plot each in its own panel.
ces = Dict{String,Any}()
plts = Any[]
for (name, generator) in generator_dict
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )
    plt = Plots.plot(
        ce, title=name, alpha=0.2, cbar=false,
        axis=nothing, length_out=10, contour_alpha=1.0,
    )
    if name == "ECCCo"
        # Overlay the conditional samples returned by `distance_from_energy`.
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            plt, _X[1, :], _X[2, :], color=:purple, shape=:star5,
            ms=10, label="x̂|$target", alpha=0.1
        )
    end
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(500, 520))
display(plt)
# Qualified like every other Plots call in this notebook.
Plots.savefig(plt, joinpath(output_images_path, "poc.png"))
```
```{julia}
ce = ces["ECCCo"]
using CounterfactualExplanations.Generators: ℓ

"""
    weighted_penalty(x, i)

Return `-λ[i] * penalty[i](ce′)`, where `ce′` is a copy of the ECCCo explanation `ce`
whose state `s′` has been set to `x`.
"""
function weighted_penalty(x, i)
    _ce = deepcopy(ce)  # work on a copy: setting `s′` must not mutate the stored `ce`
    _ce.s′ = x
    λ = _ce.generator.λ[i]
    _loss = _ce.generator.penalty[i]
    return -λ * _loss(_ce)
end

# loss (negated ℓ of the generator at state `x`)
function f1(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    return -ℓ(_ce.generator, _ce)
end
f2(x) = weighted_penalty(x, 1)  # cost
f3(x) = weighted_penalty(x, 2)  # set size
f4(x) = weighted_penalty(x, 3)  # distance from energy

# Helper: flatten the Cartesian grid x × y into two coordinate vectors (x varies fastest).
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))
xlims, ylims = extrema(X, dims=2)
xrange = range(xlims..., length=10)
yrange = range(ylims..., length=10)
x1, x2 = meshgrid(xrange, yrange)
# One (x1[k], x2[k]) grid point per arrow.
inputs = zip(x1, x2)
scale = 0.1
# Compute each gradient once and reuse it for both arrow components.
grads = [gradient(f1, [a, b][:, :])[1] for (a, b) in inputs]
u = [scale * g[1] for g in grads]
v = [scale * g[2] for g in grads]
Plots.plot(xlims=xlims, ylims=ylims)
Plots.scatter!(counterfactual_data)
Plots.quiver!(x1, x2, quiver=(u, v))
```