```{julia}
# Run the shared notebook setup script.
include("notebooks/setup.jl")
# Evaluate `setup_notebooks` (presumably an expression defined by the script
# above — confirm in setup.jl). The trailing `;` suppresses the cell output.
eval(setup_notebooks);
```

# Synthetic data

```{julia}
#| output: false

# Data:
# Synthetic datasets keyed by name. The `load_*` helpers are not defined in
# this file (presumably provided by the notebook setup — confirm).
datasets = Dict(
    :linearly_separable => load_linearly_separable(),
    :overlapping => load_overlapping(),
    :moons => load_moons(),
    :circles => load_circles(),
    :multi_class => load_multi_class(),
)

# Hyperparameters:
# Coverage rates; each one is passed as `coverage` to `conformal_model`.
cvgs = [0.5, 0.75, 0.95]
# Temperatures; each one is passed as `temp` to `CCEGenerator`.
temps = [0.01, 0.1, 1.0]
# Values used as the second entry of the `λ` vector of the CCE-style generators.
Λ = [0.0, 0.1, 1.0]
# First entry of that `λ` vector; also the scalar `λ` of the Wachter and REVISE
# generators in the benchmark.
l2_λ = 0.1

# Classifiers:
# Number of training epochs shared by all neural-network classifiers below.
epochs = 250
# Activation function handed to the MLJFlux builders as `σ`.
link_fun = relu
logreg = NeuralNetworkClassifier(builder=MLJFlux.Linear(σ=link_fun), epochs=epochs)
mlp = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=link_fun), epochs=epochs)
ensmbl = EnsembleModel(model=mlp, n=5)
# Only the entries that are not commented out take part in the experiments.
classifiers = Dict(
    # :logreg => logreg,
    :mlp => mlp,
    # :ensmbl => ensmbl,
)

# Search parameters:
# NOTE(review): `target` and `factual` are not referenced in the chunks visible
# here (the search loop derives its own labels) — confirm whether they are
# still needed.
target = 2
factual = 1
# Passed as `max_iter` and `gradient_tol` to `generate_counterfactual`.
max_iter = 50
gradient_tol = 1e-2
# Optimiser shared by every counterfactual generator in this notebook.
opt = Descent(0.01)
```

```{julia}
#| echo: false

# Grid search: for every dataset, classifier and coverage rate a conformal
# classifier is fitted once; a counterfactual is then generated for every
# (λ, temperature) pair. Each run becomes one row of `results`; runs that throw
# are kept with `ce = missing` so the grid stays complete.
results = DataFrame()
for (dataname, data) in datasets

    # Data:
    X = table(permutedims(data.X))
    y = data.output_encoder.labels
    # One factual per dataset, selected by a random index over the second
    # dimension of `data.X`, and reused for all settings below.
    x = select_factual(data, rand(1:size(data.X, 2)))

    for (clf_name, clf) in classifiers, cov in cvgs

        # Classifier and coverage:
        conf_model = conformal_model(clf; method=:simple_inductive, coverage=cov)
        mach = machine(conf_model, X, y)
        fit!(mach)
        M = CCE.ConformalModel(mach.model, mach.fitresult)

        # Set up CCE:
        factual_label = predict_label(M, data, x)[1]
        # Target is the first level that differs from the factual prediction.
        target_label = data.y_levels[data.y_levels .!= factual_label][1]

        for λ in Λ, temp in temps

            # CCE for given classifier, coverage, temperature and λ:
            generator = CCEGenerator(temp=temp, λ=[l2_λ, λ], opt=opt)
            # Compare the predicted label itself (first element, as above).
            # Comparing the whole return value with a single label could never
            # be equal, so the check could never fail.
            @assert predict_label(M, data, x)[1] != target_label
            ce = try
                generate_counterfactual(
                    x, target_label, data, M, generator;
                    initialization=:identity,
                    converge_when=:generator_conditions,
                    gradient_tol=gradient_tol,
                    max_iter=max_iter,
                )
            catch err
                # Best effort: keep going, but leave a trace of what failed.
                @warn "Counterfactual search failed; storing `missing`." dataname clf_name cov temp λ exception = (err, catch_backtrace())
                missing
            end

            _results = DataFrame(
                dataset = dataname,
                classifier = clf_name,
                coverage = cov,
                temperature = temp,
                λ = λ,
                ce = ce,
                factual = factual_label,
                target = target_label,
            )
            # `promote=true` lets the `ce` column widen so that it can hold both
            # counterfactual objects and `missing`; without it the first mixed
            # append throws.
            append!(results, _results; promote=true)

        end

    end

end
```

```{julia}
#| echo: false

"""
    plot_ce(results; dataset=:multi_class, classifier=:mlp, λ=0.1, img_height=300, zoom=-0.2)

Plot the counterfactuals stored in `results` for one `dataset`, `classifier`
and `λ` as a grid of panels, one panel per matching row.

The grid has `length(cvgs)` rows and `length(temps)` columns (both are globals
of this notebook). Each panel is titled with the coverage and temperature of
its row. Rows whose `ce` is `missing` (failed searches) are drawn as empty
placeholder panels so that the grid keeps its shape.

# Keywords
- `dataset`, `classifier`, `λ`: values matched against the columns of the same
  name in `results`.
- `img_height`: size in pixels of one grid cell.
- `zoom`: forwarded to `Plots.plot` for each counterfactual.
"""
function plot_ce(results; dataset=:multi_class, classifier=:mlp, λ=0.1, img_height=300, zoom=-0.2)
    keep = (results.dataset .== dataset) .&
        (results.classifier .== classifier) .&
        (results.λ .== λ)
    df_plot = results[keep, :]
    plts = map(eachrow(df_plot)) do row
        panel_title = "cov: $(row.coverage), temp: $(row.temperature)"
        if ismissing(row.ce)
            # Failed search: nothing to draw, keep an empty titled panel.
            Plots.plot(; title=panel_title, legend=false, framestyle=:none)
        else
            Plots.plot(
                row.ce;
                title=panel_title,
                cbar=false,
                zoom=zoom,
                legend=false,
            )
        end
    end
    # Named `n_rows`/`n_cols` to avoid shadowing `DataFrames.nrow`.
    n_rows = length(cvgs)
    n_cols = length(temps)
    _layout = (n_rows, n_cols)
    return Plots.plot(
        plts...;
        # `size` is (width, height), hence the reversed layout tuple.
        size=img_height .* reverse(_layout), layout=_layout,
        plot_title="λ: $λ, dataset: $dataset, classifier: $classifier",
    )
end
```

```{julia}
#| output: true
#| echo: false

# Emit one section per dataset and classifier, with a tabset that holds one
# counterfactual grid per λ. The parsed Markdown is passed to `display`
# explicitly: the value of an expression inside a `for` loop is discarded, so
# without `display` no heading or tabset fence would be emitted.
for dataset in keys(datasets)
    display(Markdown.parse("""### $dataset"""))
    for classifier in keys(classifiers)
        display(Markdown.parse("""#### $classifier"""))
        display(Markdown.parse("""::: {.panel-tabset}"""))
        for λ in Λ
            display(Markdown.parse("""##### λ: $λ"""))
            display(plot_ce(results; dataset=dataset, classifier=classifier, λ=λ))
        end
        display(Markdown.parse(""":::"""))
    end
end
```

## Benchmark

```{julia}
# Baseline generators; all share the optimiser `opt`.
generators = Dict(
    :wachter => GenericGenerator(; opt=opt, λ=l2_λ),
    :revise => REVISEGenerator(; opt=opt, λ=l2_λ),
    :greedy => GreedyGenerator(; opt=opt),
)

# Untrained Models:
# One conformal MLP per coverage rate, keyed by "cov<rate>".
models = Dict(
    "cov$(cov)" => CCE.ConformalModel(
        conformal_model(mlp; method=:simple_inductive, coverage=cov),
    ) for cov in cvgs
)
```

```{julia}
# Measures evaluated for every counterfactual in the benchmark.
measures = [
    CounterfactualExplanations.distance,
    CCE.distance_from_energy,
    CCE.distance_from_targets,
    CounterfactualExplanations.validity,
]

# One benchmark per dataset, λ and temperature, collected here and
# concatenated at the end.
bmks = []
for (dataname, dataset) in datasets
    for λ in Λ
        for temp in temps
            # Start from the baseline generators and add the λ/temperature
            # dependent ones.
            gens = deepcopy(generators)
            gens[:cce] = CCEGenerator(; temp=temp, λ=[l2_λ, λ], opt=opt)
            gens[:energy] = CCE.EnergyDrivenGenerator(; λ=[l2_λ, λ], opt=opt)
            gens[:target] = CCE.TargetDrivenGenerator(; λ=[l2_λ, λ], opt=opt)

            bmk_i = benchmark(
                dataset;
                models=deepcopy(models),
                generators=gens,
                measure=measures,
                suppress_training=false,
                dataname=dataname,
                n_individuals=5,
                initialization=:identity,
            )

            # Tag every evaluation row with the setting that produced it.
            bmk_i.evaluation.λ .= λ
            bmk_i.evaluation.temperature .= temp
            push!(bmks, bmk_i)
        end
    end
end
bmk = reduce(vcat, bmks)
```

```{julia}
# Write the table returned by calling `bmk()` to `synthetic_benchmark.csv`
# under `output_path` (defined outside this file — presumably in the setup).
CSV.write(joinpath(output_path, "synthetic_benchmark.csv"), bmk())
```

```{julia}
#| output: true
#| echo: false

# Box plots of the benchmark: one figure per dataset, λ and temperature, with
# one facet row per measure (`variable`) and one facet column per model.
df = bmk()
for dataname ∈ sort(unique(df.dataname))
    # Interpolate the loop variable `dataname`; no `dataset` is defined in this
    # loop. `display` is required because values inside a `for` loop are
    # otherwise discarded.
    display(Markdown.parse("""### $dataname"""))
    df_ = df[df.dataname .== dataname, :]
    for λ in Λ, temp in temps
        # One figure per (λ, temperature), so the heading names both.
        display(Markdown.parse("""#### λ: $λ, temperature: $temp"""))
        df_plot = df_[df_.λ .== λ, :]
        df_plot = df_plot[df_plot.temperature .== temp, :]
        plt = AlgebraOfGraphics.data(df_plot) * visual(BoxPlot) *
            mapping(:generator, :value, row=:variable, col=:model, color=:generator)
        plt = draw(
            plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=200, height=180),
            facet=(; linkyaxes=:minimal)
        )
        display(plt)
    end
end
```