```{julia}
# Load the shared notebook setup, then evaluate its `setup_notebooks` expression in this
# notebook's scope. NOTE(review): presumably this brings packages and paths such as
# `output_path` / `output_images_path` (used below) into scope — confirm in setup.jl.
include("notebooks/setup.jl")
eval(setup_notebooks)
```


```{julia}
# Counterfactual data: the moons toy dataset (`load_moons(2500; noise=0.1)`).
# X is stored features × observations, so its second dimension is the sample count.
counterfactual_data = load_moons(2500; noise=0.1)
X, y = counterfactual_data.X, counterfactual_data.y
labels = counterfactual_data.output_encoder.labels
input_dim, nobs = size(X)
# Mini-batch size: one tenth of the observations, rounded to the nearest integer.
batch_size = round(Int, nobs / 10)
epochs = 500
```

```{julia}
# Scatter plot of the training data. `scatter!` draws onto the *current* plot, so an
# empty plot is created first to start from a clean canvas.
Plots.plot()
display(Plots.scatter!(counterfactual_data))
```

```{julia}
# Distributions handed to the conditional sampler: a standard normal over inputs and a
# uniform categorical over the class labels. NOTE(review): presumably used to initialise
# samples and draw target classes — confirm against ConditionalSampler.
# The class count is taken from the data rather than hard-coded to 2.
n_classes = length(levels(labels))
𝒟x = Normal()
𝒟y = Categorical(fill(1 / n_classes, n_classes))
# `size(X)[1:end-1]` drops the trailing observation dimension, leaving the input shape.
sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=size(X)[1:end-1], batch_size=10)
n_hidden = 32
clf = JointEnergyClassifier(
    sampler;
    # MLP with three hidden layers of `n_hidden` ReLU units.
    builder=MLJFlux.MLP(
        hidden=(n_hidden, n_hidden, n_hidden),
        σ=Flux.relu,
    ),
    batch_size=batch_size,
    # Identity finaliser: the network emits raw logits, as `logitcrossentropy` expects.
    finaliser=x -> x,
    loss=Flux.Losses.logitcrossentropy,
    jem_training_params=(
        # NOTE(review): presumably the weights of the JEM loss terms — confirm.
        α=[1.0, 1.0, 1e-1],
        verbosity=50,
    ),
    epochs=epochs,
    sampling_steps=50,
)
```


```{julia}
# Wrap the JEM classifier in a simple-inductive conformal model with 95% target coverage,
# fit it, and serialize the fitted machine to `output_path`.
# The coverage is stored as `coverage_level` (not `cov`) so it does not shadow
# `Statistics.cov`; assigning to an already-resolved imported `cov` would error.
method = :simple_inductive
coverage_level = 0.95
conf_model = conformal_model(clf; method=method, coverage=coverage_level)
# X holds observations in columns (see `size(X)` above); MLJ tables want them in rows.
mach = machine(conf_model, table(permutedims(X)), labels)
fit!(mach)
Serialization.serialize(joinpath(output_path, "poc_model.jls"), mach)
```


```{julia}
#| echo: false
# For each class, plot the model's predicted class probability as a filled contour,
# overlaid with the training data and the samples JEM generates for that class.
niter = 1000
jem = mach.model.model.jem
batch_size = mach.model.model.batch_size
# Rebinds the global X to a Float32 matrix; the counterfactual cell below reuses it.
X = Float32.(matrix(X))
# Number of classes, taken from the data instead of being hard-coded to 2.
n_classes = length(levels(labels))
# Class-conditional samples can only be drawn with a conditional sampler.
if jem.sampler isa ConditionalSampler
    label_codes = Int.(labels.refs)  # integer class codes used for colour/grouping
    # `map ... do` keeps per-class names (`xlims`, `plt`, ...) strictly local, avoiding
    # top-level soft-scope clashes with globals such as Plots' `xlims`/`ylims`.
    plts = map(1:n_classes) do target
        # Generated samples for class `target`; columns are samples, rows are features.
        X̂ = generate_conditional_samples(jem, batch_size, target; niter=niter)
        # Axis limits covering both the observed data and the generated samples.
        ex = extrema(hcat(X, X̂), dims=2)
        xlims, ylims = ex[1], ex[2]
        x1 = range(1.0f0 .* xlims..., length=100)
        x2 = range(1.0f0 .* ylims..., length=100)
        # Filled contour of the softmax probability assigned to class `target`.
        plt = Plots.contour(
            x1, x2, (x, y) -> softmax(jem([x, y]))[target];
            fill=true, alpha=0.5, title="Target: $target", cbar=true,
            xlims=xlims,
            ylims=ylims,
        )
        Plots.scatter!(
            plt, X[1, :], X[2, :];
            color=label_codes, group=label_codes, alpha=0.5,
        )
        Plots.scatter!(
            plt, X̂[1, :], X̂[2, :];
            color=fill(target, size(X̂, 2)),
            group=fill(target, size(X̂, 2)),
            shape=:star5, ms=10,
        )
        plt
    end
    plt = Plots.plot(plts..., layout=(1, n_classes), size=(n_classes * 500, 400))
    display(plt)
end
```


```{julia}
Random.seed!(1234)

# Penalty strengths: λ₁ for the baseline generators, the full vector Λ for ECCCo.
λ₁, λ₂, λ₃ = 0.1, 0.5, 0.5
Λ = [λ₁, λ₂, λ₃]

# Conformal wrapper around the fitted machine; every generator below queries this model.
M = ECCCo.ConformalModel(mach.model, mach.fitresult)

# Factual: a randomly chosen observation that the model assigns to the first class,
# reshaped into a single-column input.
factual_label = levels(labels)[1]
x_factual = let
    predicted = predict_label(M, counterfactual_data)
    candidates = findall(==(factual_label), predicted)
    reshape(X[:, rand(candidates)], input_dim, 1)
end
# Target: the second class.
target = levels(labels)[2]
factual = first(predict_label(M, counterfactual_data, x_factual))
γ = 0.5

generator_dict = OrderedDict(
    "Wachter (γ=$γ)" => WachterGenerator(λ = λ₁),
    "Schut" => GreedyGenerator(λ = λ₁),
    "REVISE" => REVISEGenerator(λ = λ₁),
    "ECCCo" => ECCCoGenerator(λ = Λ),
)

ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
    # Only the Wachter baseline stops on the decision threshold γ; the others use
    # their generators' own convergence conditions.
    local rule = if name == "Wachter (γ=$γ)"
        :decision_threshold
    else
        :generator_conditions
    end
    local ce_gen = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=rule,
        decision_threshold=γ,
    )
    local panel = Plots.plot(ce_gen; title=name, alpha=0.2, cbar=false, axis=nothing)
    if name == "ECCCo"
        # Overlay the conditional samples returned alongside the energy distance.
        local x_cond = distance_from_energy(ce_gen, return_conditionals=true)
        Plots.scatter!(
            x_cond[1, :], x_cond[2, :];
            color=:purple, shape=:star5, ms=10, label="x̂|$target", alpha=0.5,
        )
    end
    push!(plts, panel)
    ces[name] = ce_gen
end
plt = Plots.plot(plts...; size=(500, 520))
display(plt)
savefig(plt, joinpath(output_images_path, "poc.png"))
```


```{julia}
# Pull out the ECCCo counterfactual generated above; as the cell's last expression it
# is also shown as the cell output.
ce = ces["ECCCo"]
```