Skip to content
Snippets Groups Projects
poc.qmd 4.90 KiB
```{julia}
# Load the shared notebook setup. The path is built from the current working
# directory (not from this file's location), so the notebook is expected to be
# run from the project root — NOTE(review): confirm this matches how Quarto is invoked.
include("$(pwd())/notebooks/setup.jl")
# Evaluate the `setup_notebooks` expression that setup.jl is expected to define
# (presumably package imports and shared configuration — confirm in setup.jl).
eval(setup_notebooks)
```


```{julia}
# Counterfactual data: synthetic blob clusters from `load_blobs`
# (cluster spread 0.1, cluster centres drawn from the box [-1, 1]).
n_obs = 1000
counterfactual_data = load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0))
X = counterfactual_data.X
y = counterfactual_data.y
labels = counterfactual_data.output_encoder.labels
# `X` stores one observation per column (features × observations).
input_dim, nobs = size(X)
# Mini-batches of roughly a tenth of the data.
batch_size = round(Int, nobs / 10)
epochs = 100
```

```{julia}
# Scatter the raw training data onto a fresh canvas and show it.
let canvas = Plots.plot()
    Plots.scatter!(canvas, counterfactual_data)
    display(canvas)
end
```

```{julia}
# Distributions handed to the conditional sampler: `Normal()` over inputs and a
# uniform `Categorical` over the two class labels.
𝒟x = Normal()
𝒟y = Categorical(fill(0.5, 2))
# Samples have the shape of a single input column of `X` (all dims but the last).
sampler = ConditionalSampler(
    𝒟x, 𝒟y;
    input_size=Base.front(size(X)),
    batch_size=50,
)

# Joint energy classifier: an MLP with three hidden layers of equal width and
# swish activations, trained on raw logits (identity finaliser + logit
# cross-entropy).
n_hidden = 16
clf = JointEnergyClassifier(
    sampler;
    builder=MLJFlux.MLP(hidden=ntuple(_ -> n_hidden, 3), σ=Flux.swish),
    loss=Flux.Losses.logitcrossentropy,
    finaliser=x -> x,
    batch_size=batch_size,
    epochs=epochs,
    sampling_steps=30,
    # Field order matters for the NamedTuple type: keep `α` before `verbosity`.
    jem_training_params=(α=[1.0, 1.0, 0.1], verbosity=10),
)
```


```{julia}
# Wrap the classifier in a conformal predictor (`:simple_inductive` method)
# targeting 95% coverage.
method = :simple_inductive
cov = 0.95
conf_model = conformal_model(clf; method=method, coverage=cov)
# MLJ wants one row per observation, while `X` is features × observations.
# `fit!` trains the machine in place and hands it back.
mach = fit!(machine(conf_model, table(permutedims(X)), labels))
# Persist the fitted machine for later reuse.
Serialization.serialize(joinpath(output_path, "poc_model.jls"), mach)
```


```{julia}
#| echo: false
niter = 1000
jem = mach.model.model.jem
batch_size = mach.model.model.batch_size
X = Float32.(matrix(X))
# Integer class codes of the training labels; loop-invariant, so computed once.
label_refs = Int.(labels.refs)
# One panel per class label.
n_classes = length(levels(labels))
# Samples below are conditioned on a target class, so this is only attempted
# when the model's sampler is a `ConditionalSampler`.
if jem.sampler isa ConditionalSampler
    # `map` + `do` gives each panel a hard local scope, so the per-target
    # temporaries (X̂, xlims, ...) neither leak into nor clobber globals.
    plts = map(1:n_classes) do target
        # Draw `batch_size` samples conditioned on `target`.
        X̂ = generate_conditional_samples(jem, batch_size, target; niter=niter)
        # Axis limits that cover both the data and the generated samples.
        ex = extrema(hcat(X, X̂), dims=2)
        xlims = ex[1]
        ylims = ex[2]
        x1 = range(1.0f0 .* xlims..., length=100)
        x2 = range(1.0f0 .* ylims..., length=100)
        # Filled contour of the softmax score assigned to `target`.
        plt = Plots.contour(
            x1, x2, (x, y) -> softmax(jem([x, y]))[target],
            fill=true, alpha=0.5, title="Target: $target", cbar=true,
            xlims=xlims,
            ylims=ylims,
        )
        # Training data, coloured by class.
        Plots.scatter!(
            plt, X[1, :], X[2, :],
            color=label_refs, group=label_refs, alpha=0.5,
        )
        # Generated samples for this target, drawn as stars.
        Plots.scatter!(
            plt, X̂[1, :], X̂[2, :],
            color=repeat([target], size(X̂, 2)),
            group=repeat([target], size(X̂, 2)),
            shape=:star5, ms=10,
        )
        return plt
    end
    plt = Plots.plot(plts..., layout=(1, n_classes), size=(n_classes * 500, 400))
    display(plt)
end
```


```{julia}
Random.seed!(1234)

# Penalty strengths: the baseline generators use λ₁ only; ECCCo takes all three.
λ₁ = 0.1
λ₂ = 0.4
λ₃ = 0.4
Λ = [λ₁, λ₂, λ₃]

M = ECCCo.ConformalModel(mach.model, mach.fitresult)
# Factual: a random training point predicted as the second label; the
# counterfactual target is the first label.
factual_label = levels(labels)[2]
factual_candidates = findall(predict_label(M, counterfactual_data) .== factual_label)
x_factual = reshape(X[:, rand(factual_candidates)], input_dim, 1)
target = levels(labels)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]

opt = Flux.Optimise.Descent(0.01)

generator_dict = OrderedDict(
    "Wachter" => WachterGenerator(λ = λ₁, opt=opt),
    "Schut" => GreedyGenerator(λ = λ₁),
    "REVISE" => REVISEGenerator(λ = λ₁, opt=opt),
    "ECCCo" => ECCCoGenerator(λ = Λ, opt=opt),
)

# Counterfactual explanations keyed by generator name.
ces = Dict{String,Any}()
plts = []
for (name, generator) in generator_dict
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity, 
        converge_when=:generator_conditions,
    )
    plt = Plots.plot(
        ce, title=name, alpha=0.2, cbar=false, 
        axis=nothing, length_out=10, contour_alpha=1.0,
    )
    if name == "ECCCo"
        # Overlay the conditional samples returned by `distance_from_energy`.
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1,:],_X[2,:], color=:purple, shape=:star5, 
            ms=10, label="x̂|$target", alpha=0.1
        )
    end
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(500,520))
display(plt)
# Qualified like every other Plots call here, so this also works when Plots is
# only `import`ed rather than `using`-ed.
Plots.savefig(plt, joinpath(output_images_path, "poc.png"))
```


```{julia}
ce = ces["ECCCo"]
using CounterfactualExplanations.Generators: ℓ

# loss: negated generator loss ℓ with the counterfactual state set to `x`
# (evaluated on a copy so `ce` itself is never mutated).
function f1(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    return - ℓ(_ce.generator, _ce)
end

# Negated `i`-th penalty of the ECCCo generator, weighted by its λ[i], with the
# counterfactual state of a copy of `ce` set to `x`.
function _weighted_penalty(x, i)
    _ce = deepcopy(ce)
    _ce.s′ = x
    λ = _ce.generator.λ[i]
    _loss = _ce.generator.penalty[i]
    return - λ * _loss(_ce)
end

# cost
f2(x) = _weighted_penalty(x, 1)

# set size
f3(x) = _weighted_penalty(x, 2)

# distance from energy
f4(x) = _weighted_penalty(x, 3)

# Helper function: flat coordinate vectors of the Cartesian grid x × y.
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))

xlims, ylims = extrema(X, dims=2)
xrange = range(xlims..., length=10)
yrange = range(ylims..., length=10)

x1, x2 = meshgrid(xrange, yrange)
# Grid points are the pairs of the meshgrid coordinates `x1`/`x2`.
inputs = zip(x1, x2)

scale = 0.1
# Evaluate the gradient of f1 once per grid point (as a 2×1 column input) and
# reuse both components for the quiver arrows.
grads = [first(gradient(f1, [p1, p2][:, :])) for (p1, p2) in inputs]
u = [scale * g[1] for g in grads]
v = [scale * g[2] for g in grads]

Plots.plot(xlims=xlims, ylims=ylims)
Plots.scatter!(counterfactual_data)
Plots.quiver!(x1, x2, quiver=(u, v))
```