Skip to content
Snippets Groups Projects
poc.qmd 8.39 KiB
Newer Older
pat-alt's avatar
pat-alt committed
```{julia}
# Load the shared notebook setup from the project root. NOTE(review): names used below but
# not defined in this notebook (e.g. `output_path`, `output_images_path`, `panel_height`)
# presumably come from here — confirm against notebooks/setup.jl.
include(joinpath(pwd(), "notebooks", "setup.jl"))
eval(setup_notebooks)
```


```{julia}
# Counterfactual data: synthetic blobs used throughout this proof of concept.
n_obs = 1000
counterfactual_data = load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0))
X = counterfactual_data.X
y = counterfactual_data.y
labels = counterfactual_data.output_encoder.labels
# X is stored features × observations.
input_dim, nobs = size(X)
# Mini-batches of roughly 10% of the data.
batch_size = round(Int, nobs / 10)
epochs = 100
```

```{julia}
let
    # Reset the current figure to an empty canvas, then overlay the labelled data points.
    Plots.plot()
    scatter_plot = Plots.scatter!(counterfactual_data)
    display(scatter_plot)
end
```

```{julia}
# Priors for the conditional sampler: standard normal over inputs and a uniform
# categorical over the two class labels.
𝒟x = Normal()
𝒟y = Categorical(ones(2) ./ 2)
sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=size(X)[1:end-1], batch_size=50)

# Joint energy classifier: a 3-hidden-layer MLP with swish activations and a softmax head.
n_hidden = 16
clf = JointEnergyClassifier(
    sampler;
    builder=MLJFlux.MLP(
        hidden=(n_hidden, n_hidden, n_hidden),
        σ=Flux.swish,
    ),
    batch_size=batch_size,
    finaliser=Flux.softmax,
    loss=Flux.Losses.crossentropy,
    jem_training_params=(
        α=[1.0, 1.0, 1e-1],
        verbosity=10,
    ),
    epochs=epochs,
    sampling_steps=30,
)
```


```{julia}
# Wrap the joint energy classifier in a conformal model (simple inductive method,
# 95% target coverage), fit it on the blob data and persist the fitted machine.
method = :simple_inductive
cov = 0.95
conf_model = conformal_model(clf; method=method, coverage=cov)
mach = machine(conf_model, table(permutedims(X)), labels)
fit!(mach)
open(joinpath(output_path, "poc_model.jls"), "w") do io
    Serialization.serialize(io, mach)
end
```


```{julia}
#| echo: false
# Sanity check of the fitted JEM: for each class, draw conditional samples from the
# energy model and overlay them on that class's predicted-probability contour.
niter = 1000
jem = mach.model.model.jem
batch_size = mach.model.model.batch_size
# NOTE: rebinds the global `X` to Float32; later chunks reuse this converted `X`.
X = Float32.(matrix(X))
if jem.sampler isa ConditionalSampler
    n_classes = length(levels(labels))
    plts = []
    for target in 1:n_classes
        X̂ = generate_conditional_samples(jem, batch_size, target; niter=niter)
        # Axis limits cover both the observed data and the generated samples.
        xlims, ylims = extrema(hcat(X, X̂), dims=2)
        x1 = range(1.0f0 .* xlims...; length=100)
        x2 = range(1.0f0 .* ylims...; length=100)
        plt = Plots.contour(
            x1, x2, (x, y) -> softmax(jem([x, y]))[target];
            fill=true, alpha=0.5, title="Target: $target", cbar=true,
            xlims=xlims,
            ylims=ylims,
        )
        Plots.scatter!(
            X[1, :], X[2, :];
            color=Int.(labels.refs), group=Int.(labels.refs), alpha=0.5,
        )
        Plots.scatter!(
            X̂[1, :], X̂[2, :];
            color=fill(target, size(X̂, 2)),
            group=fill(target, size(X̂, 2)),
            shape=:star5, ms=10,
        )
        push!(plts, plt)
    end
    # One panel per class, side by side.
    plt = Plots.plot(plts...; layout=(1, n_classes), size=(n_classes * 500, 400))
    display(plt)
end
```


```{julia}
#| output: true
#| echo: false
#| label: fig-losses
#| fig-cap: "Illustration of the smooth size loss and the configurable classification loss."

temp = 0.25
# Left panel: smooth set-size loss surface of the conformal model.
p1 = Plots.contourf(
    mach.model, mach.fitresult, permutedims(X), labels;
    plot_set_loss=true, zoom=0, temp=temp,
)
# Right panel: classification loss with an all-ones loss matrix.
p2 = Plots.contourf(
    mach.model, mach.fitresult, permutedims(X), labels;
    plot_classification_loss=true, zoom=0, temp=temp,
    clim=nothing, loss_matrix=ones(2, 2),
)
display(Plots.plot(p1, p2; size=(800, 320)))
```

```{julia}
Random.seed!(1234)

# Penalty weights; for ECCCo they are ordered [cost, set size, distance from energy].
λ₁ = 0.1
λ₂ = 0.4
λ₃ = 0.4
Λ = [λ₁, λ₂, λ₃]

# Conformal model built from the fitted machine; the black-box model for every generator.
M = ECCCo.ConformalModel(mach.model, mach.fitresult)
# Pick a random observation predicted as the second class and search for a
# counterfactual in the first class.
factual_label = levels(labels)[2]
candidates = findall(predict_label(M, counterfactual_data) .== factual_label)
x_factual = reshape(X[:, rand(candidates)], input_dim, 1)
target = levels(labels)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]

opt = Flux.Optimise.Descent(0.01)

generator_dict = OrderedDict(
    "Wachter" => WachterGenerator(λ = λ₁, opt=opt),
    "Schut" => GreedyGenerator(λ = λ₁),
    "REVISE" => REVISEGenerator(λ = λ₁, opt=opt),
    "ECCCo" => ECCCoGenerator(λ = Λ, opt=opt),
)

ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )
    plt = Plots.plot(
        ce, title=name, alpha=0.2, cbar=false,
        axis=nothing, length_out=10, contour_alpha=1.0,
    )
    # For ECCCo, overlay the conditional samples returned by `distance_from_energy`.
    if name == "ECCCo"
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1, :], _X[2, :], color=:purple, shape=:star5,
            ms=10, label="x̂|$target", alpha=0.1,
        )
    end
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(500, 520))
display(plt)
savefig(plt, joinpath(output_images_path, "poc.png"))
```

pat-alt's avatar
pat-alt committed
```{julia}
using Colors
using CounterfactualExplanations.Generators: ∇

col_pal = palette(:seaborn_colorblind)
Random.seed!(1234)

# Penalty weights, ordered as the ECCCo generator's penalties:
# [cost, set size, distance from energy].
λ₁ = 0.1
λ₂ = 0.4  # same value as in the proof-of-concept chunk; defined here so this chunk is self-contained
λ₃ = 0.5
Λ = [λ₁, λ₂, λ₃]
η = 0.01  # gradient-descent step size, also used to scale the gradient field below

M = ECCCo.ConformalModel(mach.model, mach.fitresult)
factual_label = levels(labels)[2]
candidates = findall(predict_label(M, counterfactual_data) .== factual_label)
x_factual = reshape(X[:, rand(candidates)], input_dim, 1)
target = levels(labels)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]

opt = Flux.Optimise.Descent(η)

# Ablations: "no EBM" zeroes the energy penalty, "no CP" zeroes the set-size penalty.
generator_dict = OrderedDict(
    "Wachter" => WachterGenerator(λ = 0.3, opt=opt),
    "ECCCo (no EBM)" => ECCCoGenerator(λ = [λ₁, λ₂, 0.0], opt=opt),
    "ECCCo (no CP)" => ECCCoGenerator(λ = [λ₁, 0.0, λ₃], opt=opt),
    "ECCCo" => ECCCoGenerator(λ = Λ, opt=opt),
)

# Gradient field:
"""
    loss_grads(generator, model, ce, x)

Return `∇(generator, model, ce′)`, where `ce′` is a deep copy of the counterfactual `ce`
whose state `s′` is replaced by `x` (converted to `Float32`). `ce` is never mutated.
"""
function loss_grads(generator, model, ce, x)
    _ce = deepcopy(ce)
    _ce.s′ = Float32.(x)
    # Use the model passed in rather than a global, so the gradient matches the caller's model.
    return ∇(generator, model, _ce)
end

"""
    meshgrid(x, y)

Return the flattened coordinates `(xs, ys)` of the grid spanned by `x` and `y`,
with `x` varying fastest; both vectors have length `length(x) * length(y)`.
"""
meshgrid(x, y) = (vec([xi for xi in x, yj in y]), vec([yj for xi in x, yj in y]))

# Evaluation grid: 10 evenly spaced points per input dimension, spanning the data range.
xlims, ylims = extrema(X, dims=2)
xrange, yrange = (range(lo, hi; length=10) for (lo, hi) in (xlims, ylims))
x1, x2 = meshgrid(xrange, yrange)
inputs = zip(x1, x2)

"""
    arrow0!(x, y, u, v; as=0.2, lw=1, lc=:black, la=1)

Draw an arrow from `(x, y)` to `(x + u, y + v)` on the current plot: a shaft plus two
head strokes, each `as` times the arrow's length. `lw`, `lc` and `la` are forwarded as
line width, colour and alpha. A zero-length vector has no direction and is skipped.
"""
function arrow0!(x, y, u, v; as=0.2, lw=1, lc=:black, la=1)
    nuv = hypot(u, v)
    iszero(nuv) && return nothing  # avoid dividing by zero below
    v1 = [u; v] / nuv    # unit vector along the arrow
    v2 = [-v; u] / nuv   # unit vector perpendicular to it
    # v1 ⟂ v2 are unit vectors, so ‖3v1 + v2‖ = √10; normalise to a unit head direction.
    v4 = (3 * v1 + v2) / sqrt(10)
    # Reflect v4 across the arrow axis to obtain the second head stroke.
    v5 = v4 - 2 * (v4' * v2) * v2
    v4, v5 = as * nuv * v4, as * nuv * v5
    Plots.plot!([x, x + u], [y, y + v]; lw=lw, lc=lc, la=la)
    Plots.plot!([x + u, x + u - v5[1]], [y + v, y + v - v5[2]]; lw=lw, lc=lc, la=la)
    return Plots.plot!([x + u, x + u - v4[1]], [y + v, y + v - v4[2]]; lw=lw, lc=lc, la=la)
end

GR.setarrowsize(0.5)

# Plot: one panel per generator showing its counterfactual path, the generated samples
# (for variants with the energy penalty enabled) and the gradient field of its objective.
ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict

    # CE:
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )

    # Main plot (path):
    plt = Plots.plot(
        ce, title=name, alpha=0.1, cbar=false,
        axis=nothing, length_out=10, contour_alpha=0.0,
        legend=false,
        palette=col_pal,
    )

    # Generated samples (only the variants whose energy penalty weight is non-zero):
    if name ∈ ("ECCCo", "ECCCo (no CP)")
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1, :], _X[2, :], color=col_pal[end-1], shape=:star5,
            ms=10, label="x̂|$target", alpha=0.1,
        )
    end

    # Gradient field: -η∇ at every grid point, i.e. one Descent(η) step on the objective.
    u = Float64[]
    v = Float64[]
    for (x, y) in inputs
        g = -loss_grads(generator, M, ce, reshape([x, y], :, 1))
        push!(u, η * g[1])
        push!(v, η * g[2])
    end
    Plots.quiver!(x1, x2, quiver=(u, v), color=col_pal[5])
    # Alternative with hand-drawn arrow heads:
    # arrow0!.(x1, x2, u, v; as=0.2, lw=1.0, lc=col_pal[5], la=1)

    push!(plts, plt)
    ces[name] = ce
end
# NOTE(review): `panel_height` is not defined in this notebook; presumably it comes from setup.jl.
plt = Plots.plot(
    plts...;
    size=(panel_height * length(plts), panel_height),
    layout=(1, length(plts)),
    dpi=300,
)
display(plt)
savefig(plt, joinpath(output_images_path, "poc_gradient_fields.png"))
```


Pat Alt's avatar
Pat Alt committed
```{julia}
ce = ces["ECCCo"]
using CounterfactualExplanations.Generators: ℓ

# The functions below evaluate negated terms of the ECCCo objective at a candidate
# state `x`. They read the global `ce` above and always work on a deep copy, so `ce`
# itself is never mutated. Because they are negated, their gradients point downhill.

"""Return a deep copy of the global counterfactual `ce` with its state `s′` set to `x`."""
function ce_with_state(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    return _ce
end

"""Return `-λ[i] * penalty[i](ce′)`: the negated i-th weighted penalty of the generator at `x`."""
function neg_weighted_penalty(x, i::Integer)
    _ce = ce_with_state(x)
    λ = _ce.generator.λ[i]
    _loss = _ce.generator.penalty[i]
    return -λ * _loss(_ce)
end

# loss
function f1(x)
    _ce = ce_with_state(x)
    return -ℓ(_ce.generator, _ce)
end

# cost
f2(x) = neg_weighted_penalty(x, 1)

# set size
f3(x) = neg_weighted_penalty(x, 2)

# distance from energy
f4(x) = neg_weighted_penalty(x, 3)

# Helper function: flattened coordinates of the grid spanned by `x` and `y`,
# returned as a tuple `(xs, ys)` with `x` varying fastest.
function meshgrid(x, y)
    grid = collect(Iterators.product(x, y))
    return vec(first.(grid)), vec(last.(grid))
end

# Evaluation grid: 10 evenly spaced points per input dimension, spanning the data range.
xlims, ylims = extrema(X, dims=2)
xrange = range(xlims..., length=10)
yrange = range(ylims..., length=10)

x1, x2 = meshgrid(xrange, yrange)
inputs = zip(x1, x2)

# Arrow scaling and buffers for the quiver plot. NOTE(review): these were previously
# undefined here (UndefVarError). 0.01 mirrors the step size η of the gradient-field
# figure — adjust if a different arrow length is intended.
scale = 0.01
u = Float64[]
v = Float64[]
for (x, y) in inputs
    # Gradient of f1 (the negated loss) w.r.t. the 2×1 state, computed once per grid point.
    g = gradient(f1, reshape([x, y], :, 1))[1]
    push!(u, scale * g[1])
    push!(v, scale * g[2])
end

Plots.plot(xlims=xlims, ylims=ylims)
Plots.scatter!(counterfactual_data)
Plots.quiver!(x1, x2, quiver=(u, v))