```{julia}
using CCE
using ConformalPrediction
using CounterfactualExplanations
using CounterfactualExplanations.Data
using Flux
using MLJBase
using MLJFlux
using Plots
```

```{julia}
# Load the toy linearly separable data set (presumably provided by
# `CounterfactualExplanations.Data`, imported above — TODO confirm) as a `CounterfactualData`.
counterfactual_data = load_linearly_separable()
```

```{julia}
# Small MLP with a single hidden ReLU layer of width 32. `n_in` and `n_out` are not
# globals: `MLJFlux.@builder` substitutes them from the training data at build time.
builder = MLJFlux.@builder Chain(Dense(n_in, 32, relu), Dense(32, n_out))
clf = NeuralNetworkClassifier(; builder, epochs=100)
```


```{julia}
# `counterfactual_data.X` stores observations as columns, while MLJ expects a table
# with one observation per row, hence the transpose before wrapping it as a table.
X = counterfactual_data.X |> permutedims |> table
y = counterfactual_data.output_encoder.labels
# Wrap the neural classifier in a simple inductive conformal predictor, bind it to the
# data and train it; `fit!` returns the (now fitted) machine itself.
conf_model = conformal_model(clf; method=:simple_inductive)
mach = fit!(machine(conf_model, X, y))
```

## Counterfactual Explanation

```{julia}
# Wrap the conformal model and its fitted parameters for use with
# `generate_counterfactual` (presumably adapting it to CounterfactualExplanations'
# model interface — TODO confirm against CCE).
M = let fitresult = mach.fitresult
    CCE.ConformalModel(conf_model, fitresult)
end
# Counterfactual generator provided by CCE.
generator = CCE.ConformalGenerator()
```

```{julia}
# Draw a random observation (observations are the columns of `counterfactual_data.X`)
# as the factual to explain.
# NOTE(review): this uses the unseeded global RNG, so the chosen factual (and every
# downstream result) changes between renders unless the RNG is seeded beforehand.
x = select_factual(counterfactual_data, rand(1:size(counterfactual_data.X, 2)))
# `predict_label` returns a collection of labels; keep the prediction for `x`.
y_factual = first(predict_label(M, counterfactual_data, x))
# Target the first class level that differs from the factual prediction. `findfirst`
# avoids materialising a boolean mask and a filtered copy of `y_levels`, and the guard
# replaces an opaque BoundsError when no other class exists.
target = let y_levels = counterfactual_data.y_levels
    idx = findfirst(!=(y_factual), y_levels)
    isnothing(idx) && throw(ArgumentError(
        "no class level differs from the factual prediction $(y_factual)",
    ))
    y_levels[idx]
end
```

```{julia}
# Two contour panels over the same fitted conformal model and data: the classification
# loss with respect to `target` (left) and the set loss (right).
p1, p2 = let args = (mach.model, mach.fitresult, X, y)
    loss_panel = contourf(args...; plot_classification_loss=true, target=target, zoom=0)
    set_panel = contourf(args...; plot_set_loss=true, zoom=0)
    loss_panel, set_panel
end
plot(p1, p2; size=(800, 320))
```

```{julia}
# Generate a counterfactual explanation for the factual `x` towards class `target`,
# using the wrapped conformal model `M` and the CCE `generator` defined above.
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
```