diff --git a/CITATION.bib b/CITATION.bib index 752944a71c777c0c19921312b6d83677ded8046e..860cec364b16c9e9ac0b66efd847685a78cb4027 100644 --- a/CITATION.bib +++ b/CITATION.bib @@ -1,7 +1,7 @@ -@misc{ECCCE.jl, +@misc{ECCCo.jl, author = {Patrick Altmeyer}, - title = {ECCCE.jl}, - url = {https://github.com/pat-alt/ECCCE.jl}, + title = {ECCCo.jl}, + url = {https://github.com/pat-alt/ECCCo.jl}, version = {v0.1.0}, year = {2023}, month = {2} diff --git a/Project.toml b/Project.toml index 647598d4a3b7da049887eb85135a2e877ac960dc..0d31db08e85cf6fc7d354ba6ab3280cc28bdc1f2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,4 +1,4 @@ -name = "ECCCE" +name = "ECCCo" uuid = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" authors = ["Patrick Altmeyer"] version = "0.1.0" diff --git a/README.md b/README.md index c599843e1c23d5e35d082b8b9dacad5e971aa00e..dff821d389ba5b0e9fddee89ca34348962c07193 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,3 @@ -# ECCCE +# ECCCo -[](https://github.com/pat-alt/ECCCE.jl/actions/workflows/CI.yml?query=branch%3Amain) +[](https://github.com/pat-alt/ECCCo.jl/actions/workflows/CI.yml?query=branch%3Amain) diff --git a/_freeze/dev/proposal/execute-results/html.json b/_freeze/dev/proposal/execute-results/html.json index 195e16d6cc088edd81dae8c7a7f5ee331e48f895..2dc9aae26865c05550c97ca996684cf942be0af9 100644 --- a/_freeze/dev/proposal/execute-results/html.json +++ b/_freeze/dev/proposal/execute-results/html.json @@ -1,7 +1,7 @@ { "hash": "d7b4f9bf7f4bff7ce610fc8be4dcfb8b", "result": { - "markdown": "---\ntitle: High-Fidelity Counterfactual Explanations through Conformal Prediction\nsubtitle: Research Proposal\nabstract: |\n We propose Conformal Counterfactual Explanations: an effortless and rigorous way to produce realistic and faithful Counterfactual Explanations using Conformal Prediction. To address the need for realistic counterfactuals, existing work has primarily relied on separate generative models to learn the data-generating process. While this is an effective way to produce plausible and model-agnostic counterfactual explanations, it not only introduces a significant engineering overhead but also reallocates the task of creating realistic model explanations from the model itself to the generative model. Recent work has shown that there is no need for any of this when working with probabilistic models that explicitly quantify their own uncertainty. Unfortunately, most models used in practice still do not fulfil that basic requirement, in which case we would like to have a way to quantify predictive uncertainty in a post-hoc fashion.\n---\n\n\n\n## Motivation\n\nCounterfactual Explanations are a powerful, flexible and intuitive way to not only explain black-box models but also enable affected individuals to challenge them through the means of Algorithmic Recourse. \n\n### Counterfactual Explanations or Adversarial Examples?\n\nMost state-of-the-art approaches to generating Counterfactual Explanations (CE) rely on gradient descent in the feature space. The key idea is to perturb inputs $x\\in\\mathcal{X}$ into a black-box model $f: \\mathcal{X} \\mapsto \\mathcal{Y}$ in order to change the model output $f(x)$ to some pre-specified target value $t\\in\\mathcal{Y}$. Formally, this boils down to defining some loss function $\\ell(f(x),t)$ and taking gradient steps in the minimizing direction. The so-generated counterfactuals are considered valid as soon as the predicted label matches the target label. A stripped-down counterfactual explanation is therefore little different from an adversarial example. In @fig-adv, for example, generic counterfactual search as in @wachter2017counterfactual has been applied to MNIST data.\n\n\n\n\n\n{#fig-adv}\n\nThe crucial difference between adversarial examples and counterfactuals is one of intent. While adversarial examples are typically intended to go unnoticed, counterfactuals in the context of Explainable AI are generally sought to be \"plausible\", \"realistic\" or \"feasible\". To fulfil this latter goal, researchers have come up with a myriad of ways. @joshi2019realistic were among the first to suggest that instead of searching counterfactuals in the feature space, we can instead traverse a latent embedding learned by a surrogate generative model. Similarly, @poyiadzi2020face use density ... Finally, @karimi2021algorithmic argues that counterfactuals should comply with the causal model that generates them [CHECK IF WE CAN PHASE THIS LIKE THIS]. Other related approaches include ... All of these different approaches have a common goal: they aim to ensure that the generated counterfactuals comply with the (learned) data-generating process (DGB). \n\n::: {#def-plausible}\n\n## Plausible Counterfactuals\n\nFormally, if $x \\sim \\mathcal{X}$ and for the corresponding counterfactual we have $x^{\\prime}\\sim\\mathcal{X}^{\\prime}$, then for $x^{\\prime}$ to be considered a plausible counterfactual, we need: $\\mathcal{X} \\approxeq \\mathcal{X}^{\\prime}$.\n\n:::\n\nIn the context of Algorithmic Recourse, it makes sense to strive for plausible counterfactuals, since anything else would essentially require individuals to move to out-of-distribution states. But it is worth noting that our ambition to meet this goal, may have implications on our ability to faithfully explain the behaviour of the underlying black-box model (arguably our principal goal). By essentially decoupling the task of learning plausible representations of the data from the model itself, we open ourselves up to vulnerabilities. Using a separate generative model to learn $\\mathcal{X}$, for example, has very serious implications for the generated counterfactuals. @fig-latent compares the results of applying REVISE [@joshi2019realistic] to MNIST data using two different Variational Auto-Encoders: while the counterfactual generated using an expressive (strong) VAE is compelling, the result relying on a less expressive (weak) VAE is not even valid. In this latter case, the decoder step of the VAE fails to yield values in $\\mathcal{X}$ and hence the counterfactual search in the learned latent space is doomed. \n\n{#fig-latent}\n\n> Here it would be nice to have another example where we poison the data going into the generative model to hide biases present in the data (e.g. Boston housing).\n\n- Latent can be manipulated: \n - train biased model\n - train VAE with biased variable removed/attacked (use Boston housing dataset)\n - hypothesis: will generate bias-free explanations\n\n### From Plausible to High-Fidelity Counterfactuals {#sec-fidelity}\n\nIn light of the findings, we propose to generally avoid using surrogate models to learn $\\mathcal{X}$ in the context of Counterfactual Explanations.\n\n::: {#prp-surrogate}\n\n## Avoid Surrogates\n\nSince we are in the business of explaining a black-box model, the task of learning realistic representations of the data should not be reallocated from the model itself to some surrogate model.\n\n:::\n\nIn cases where the use of surrogate models cannot be avoided, we propose to weigh the plausibility of counterfactuals against their fidelity to the black-box model. In the context of Explainable AI, fidelity is defined as describing how an explanation approximates the prediction of the black-box model [@molnar2020interpretable]. Fidelity has become the default metric for evaluating Local Model-Agnostic Models, since they often involve local surrogate models whose predictions need not always match those of the black-box model. \n\nIn the case of Counterfactual Explanations, the concept of fidelity has so far been ignored. This is not altogether surprising, since by construction and design, Counterfactual Explanations work with the predictions of the black-box model directly: as stated above, a counterfactual $x^{\\prime}$ is considered valid if and only if $f(x^{\\prime})=t$, where $t$ denote some target outcome. \n\nDoes fidelity even make sense in the context of CE, and if so, how can we define it? In light of the examples in the previous section, we think it is urgent to introduce a notion of fidelity in this context, that relates to the distributional properties of the generated counterfactuals. In particular, we propose that a high-fidelity counterfactual $x^{\\prime}$ complies with the class-conditional distribution $\\mathcal{X}_{\\theta} = p_{\\theta}(X|y)$ where $\\theta$ denote the black-box model parameters. \n\n::: {#def-fidele}\n\n## High-Fidelity Counterfactuals\n\nLet $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\\theta$. Then for $x^{\\prime}$ to be considered a high-fidelity counterfactual, we need: $\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}$ where $t$ denotes the target outcome.\n\n:::\n\nIn order to assess the fidelity of counterfactuals, we propose the following two-step procedure:\n\n1) Generate samples $X_{\\theta}|y$ and $X^{\\prime}$ from $\\mathcal{X}_{\\theta}|t$ and $\\mathcal{X}^{\\prime}$, respectively.\n2) Compute the Maximum Mean Discrepancy (MMD) between $X_{\\theta}|y$ and $X^{\\prime}$. \n\nIf the computed value is different from zero, we can reject the null-hypothesis of fidelity.\n\n> Two challenges here: 1) implementing the sampling procedure in @grathwohl2020your; 2) it is unclear if MMD is really the right way to measure this. \n\n## Conformal Counterfactual Explanations\n\nIn @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models. \n\n### Minimizing Predictive Uncertainty\n\n@schut2021generating demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. Similarly, @antoran2020getting ...\n\n- Problem: restricted to Bayesian models.\n- Solution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction. \n\n### Background on Conformal Prediction\n\n- Distribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.\n- Conformal prediction is instance-based. So is CE. \n- Take any fitted model and turn it into a conformal model using calibration data.\n- Our approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway. \n\n> Does the coverage guarantee carry over to counterfactuals?\n\n### Generating Conformal Counterfactuals\n\nWhile Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, @stutz2022learning introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search ...\n\n> Challenge: still need to implement these loss functions. \n\n## Experiments\n\n### Research Questions\n\n- Is CP alone enough to ensure realistic counterfactuals?\n- Do counterfactuals improve further as the models get better?\n- Do counterfactuals get more realistic as coverage\n- What happens as we vary coverage and setsize?\n- What happens as we improve the model robustness?\n- What happens as we improve the model's ability to incorporate predictive uncertainty (deep ensemble, laplace)?\n- What happens if we combine with DiCE, ClaPROAR, Gravitational?\n- What about CE robustness to endogenous shifts [@altmeyer2023endogenous]?\n\n- Benchmarking:\n - add PROBE [@pawelczyk2022probabilistically] into the mix.\n - compare travel costs to domain shits.\n\n> Nice to have: What about using Laplace Approximation, then Conformal Prediction? What about using Conformalised Laplace? \n\n## References\n\n", + "markdown": "---\ntitle: High-Fidelity Counterfactual Explanations through Conformal Prediction\nsubtitle: Research Proposal\nabstract: |\n We propose Conformal Counterfactual Explanations: an effortless and rigorous way to produce realistic and faithful Counterfactual Explanations using Conformal Prediction. To address the need for realistic counterfactuals, existing work has primarily relied on separate generative models to learn the data-generating process. While this is an effective way to produce plausible and model-agnostic counterfactual explanations, it not only introduces a significant engineering overhead but also reallocates the task of creating realistic model explanations from the model itself to the generative model. Recent work has shown that there is no need for any of this when working with probabilistic models that explicitly quantify their own uncertainty. Unfortunately, most models used in practice still do not fulfil that basic requirement, in which case we would like to have a way to quantify predictive uncertainty in a post-hoc fashion.\n---\n\n\n\n## Motivation\n\nCounterfactual Explanations are a powerful, flexible and intuitive way to not only explain black-box models but also enable affected individuals to challenge them through the means of Algorithmic Recourse. \n\n### Counterfactual Explanations or Adversarial Examples?\n\nMost state-of-the-art approaches to generating Counterfactual Explanations (CE) rely on gradient descent in the feature space. The key idea is to perturb inputs $x\\in\\mathcal{X}$ into a black-box model $f: \\mathcal{X} \\mapsto \\mathcal{Y}$ in order to change the model output $f(x)$ to some pre-specified target value $t\\in\\mathcal{Y}$. Formally, this boils down to defining some loss function $\\ell(f(x),t)$ and taking gradient steps in the minimizing direction. The so-generated counterfactuals are considered valid as soon as the predicted label matches the target label. A stripped-down counterfactual explanation is therefore little different from an adversarial example. In @fig-adv, for example, generic counterfactual search as in @wachter2017counterfactual has been applied to MNIST data.\n\n\n\n\n\n{#fig-adv}\n\nThe crucial difference between adversarial examples and counterfactuals is one of intent. While adversarial examples are typically intended to go unnoticed, counterfactuals in the context of Explainable AI are generally sought to be \"plausible\", \"realistic\" or \"feasible\". To fulfil this latter goal, researchers have come up with a myriad of ways. @joshi2019realistic were among the first to suggest that instead of searching counterfactuals in the feature space, we can instead traverse a latent embedding learned by a surrogate generative model. Similarly, @poyiadzi2020face use density ... Finally, @karimi2021algorithmic argues that counterfactuals should comply with the causal model that generates them [CHECK IF WE CAN PHASE THIS LIKE THIS]. Other related approaches include ... All of these different approaches have a common goal: they aim to ensure that the generated counterfactuals comply with the (learned) data-generating process (DGB). \n\n::: {#def-plausible}\n\n## Plausible Counterfactuals\n\nFormally, if $x \\sim \\mathcal{X}$ and for the corresponding counterfactual we have $x^{\\prime}\\sim\\mathcal{X}^{\\prime}$, then for $x^{\\prime}$ to be considered a plausible counterfactual, we need: $\\mathcal{X} \\approxeq \\mathcal{X}^{\\prime}$.\n\n:::\n\nIn the context of Algorithmic Recourse, it makes sense to strive for plausible counterfactuals, since anything else would essentially require individuals to move to out-of-distribution states. But it is worth noting that our ambition to meet this goal, may have implications on our ability to faithfully explain the behaviour of the underlying black-box model (arguably our principal goal). By essentially decoupling the task of learning plausible representations of the data from the model itself, we open ourselves up to vulnerabilities. Using a separate generative model to learn $\\mathcal{X}$, for example, has very serious implications for the generated counterfactuals. @fig-latent compares the results of applying REVISE [@joshi2019realistic] to MNIST data using two different Variational Auto-Encoders: while the counterfactual generated using an expressive (strong) VAE is compelling, the result relying on a less expressive (weak) VAE is not even valid. In this latter case, the decoder step of the VAE fails to yield values in $\\mathcal{X}$ and hence the counterfactual search in the learned latent space is doomed. \n\n{#fig-latent}\n\n> Here it would be nice to have another example where we poison the data going into the generative model to hide biases present in the data (e.g. Boston housing).\n\n- Latent can be manipulated: \n - train biased model\n - train VAE with biased variable removed/attacked (use Boston housing dataset)\n - hypothesis: will generate bias-free explanations\n\n### From Plausible to High-Fidelity Counterfactuals {#sec-fidelity}\n\nIn light of the findings, we propose to generally avoid using surrogate models to learn $\\mathcal{X}$ in the context of Counterfactual Explanations.\n\n::: {#prp-surrogate}\n\n## Avoid Surrogates\n\nSince we are in the business of explaining a black-box model, the task of learning realistic representations of the data should not be reallocated from the model itself to some surrogate model.\n\n:::\n\nIn cases where the use of surrogate models cannot be avoided, we propose to weigh the plausibility of counterfactuals against their fidelity to the black-box model. In the context of Explainable AI, fidelity is defined as describing how an explanation approximates the prediction of the black-box model [@molnar2020interpretable]. Fidelity has become the default metric for evaluating Local Model-Agnostic Models, since they often involve local surrogate models whose predictions need not always match those of the black-box model. \n\nIn the case of Counterfactual Explanations, the concept of fidelity has so far been ignored. This is not altogether surprising, since by construction and design, Counterfactual Explanations work with the predictions of the black-box model directly: as stated above, a counterfactual $x^{\\prime}$ is considered valid if and only if $f(x^{\\prime})=t$, where $t$ denote some target outcome. \n\nDoes fidelity even make sense in the context of CE, and if so, how can we define it? In light of the examples in the previous section, we think it is urgent to introduce a notion of fidelity in this context, that relates to the distributional properties of the generated counterfactuals. In particular, we propose that a high-fidelity counterfactual $x^{\\prime}$ complies with the class-conditional distribution $\\mathcal{X}_{\\theta} = p_{\\theta}(X|y)$ where $\\theta$ denote the black-box model parameters. \n\n::: {#def-fidele}\n\n## High-Fidelity Counterfactuals\n\nLet $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\\theta$. Then for $x^{\\prime}$ to be considered a high-fidelity counterfactual, we need: $\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}$ where $t$ denotes the target outcome.\n\n:::\n\nIn order to assess the fidelity of counterfactuals, we propose the following two-step procedure:\n\n1) Generate samples $X_{\\theta}|y$ and $X^{\\prime}$ from $\\mathcal{X}_{\\theta}|t$ and $\\mathcal{X}^{\\prime}$, respectively.\n2) Compute the Maximum Mean Discrepancy (MMD) between $X_{\\theta}|y$ and $X^{\\prime}$. \n\nIf the computed value is different from zero, we can reject the null-hypothesis of fidelity.\n\n> Two challenges here: 1) implementing the sampling procedure in @grathwohl2020your; 2) it is unclear if MMD is really the right way to measure this. \n\n## Conformal Counterfactual Explanations\n\nIn @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCo), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models. \n\n### Minimizing Predictive Uncertainty\n\n@schut2021generating demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. Similarly, @antoran2020getting ...\n\n- Problem: restricted to Bayesian models.\n- Solution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction. \n\n### Background on Conformal Prediction\n\n- Distribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.\n- Conformal prediction is instance-based. So is CE. \n- Take any fitted model and turn it into a conformal model using calibration data.\n- Our approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway. \n\n> Does the coverage guarantee carry over to counterfactuals?\n\n### Generating Conformal Counterfactuals\n\nWhile Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, @stutz2022learning introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search ...\n\n> Challenge: still need to implement these loss functions. \n\n## Experiments\n\n### Research Questions\n\n- Is CP alone enough to ensure realistic counterfactuals?\n- Do counterfactuals improve further as the models get better?\n- Do counterfactuals get more realistic as coverage\n- What happens as we vary coverage and setsize?\n- What happens as we improve the model robustness?\n- What happens as we improve the model's ability to incorporate predictive uncertainty (deep ensemble, laplace)?\n- What happens if we combine with DiCE, ClaPROAR, Gravitational?\n- What about CE robustness to endogenous shifts [@altmeyer2023endogenous]?\n\n- Benchmarking:\n - add PROBE [@pawelczyk2022probabilistically] into the mix.\n - compare travel costs to domain shits.\n\n> Nice to have: What about using Laplace Approximation, then Conformal Prediction? What about using Conformalised Laplace? \n\n## References\n\n", "supporting": [ "proposal_files/figure-html" ], diff --git a/_freeze/notebooks/intro/execute-results/html.json b/_freeze/notebooks/intro/execute-results/html.json index 28ce12ee42a881b7ab0755962121c93e8e6fb4bc..f296ca9b508666f8bab035587be2ff18a2adf943 100644 --- a/_freeze/notebooks/intro/execute-results/html.json +++ b/_freeze/notebooks/intro/execute-results/html.json @@ -1,7 +1,7 @@ { "hash": "43d5045964ca39def434cb65914681bc", "result": { - "markdown": "::: {.cell execution_count=1}\n``` {.julia .cell-code}\ninclude(\"notebooks/setup.jl\")\neval(setup_notebooks)\n```\n:::\n\n\n# `ConformalGenerator`\n\nIn this section, we will look at a simple example involving synthetic data, a black-box model and a generic Conformal Counterfactual Generator.\n\n## Black-box Model\n\nWe consider a simple binary classification problem. Let $(X_i, Y_i), \\ i=1,...,n$ denote our feature-label pairs and let $\\mu: \\mathcal{X} \\mapsto \\mathcal{Y}$ denote the mapping from features to labels. For illustration purposes, we will use linearly separable data. \n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\ncounterfactual_data = load_linearly_separable()\n```\n:::\n\n\nWhile we could use a linear classifier in this case, let's pretend we need a black-box model for this task and rely on a small Multi-Layer Perceptron (MLP):\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nbuilder = MLJFlux.@builder Flux.Chain(\n Dense(n_in, 32, relu),\n Dense(32, n_out)\n)\nclf = NeuralNetworkClassifier(builder=builder, epochs=100)\n```\n:::\n\n\nWe can fit this model to data to produce plug-in predictions. \n\n## Conformal Prediction\n\nHere we will instead use a specific case of CP called *split conformal prediction* which can then be summarized as follows:^[In other places split conformal prediction is sometimes referred to as *inductive* conformal prediction.]\n\n1. Partition the training into a proper training set and a separate calibration set: $\\mathcal{D}_n=\\mathcal{D}^{\\text{train}} \\cup \\mathcal{D}^{\\text{cali}}$.\n2. Train the machine learning model on the proper training set: $\\hat\\mu_{i \\in \\mathcal{D}^{\\text{train}}}(X_i,Y_i)$.\n\nThe model $\\hat\\mu_{i \\in \\mathcal{D}^{\\text{train}}}$ can now produce plug-in predictions. \n\n::: callout-note\n\n## Starting Point\n\nNote that this represents the starting point in applications of Algorithmic Recourse: we have some pre-trained classifier $M$ for which we would like to generate plausible Counterfactual Explanations. Next, we turn to the calibration step. \n:::\n\n3. Compute nonconformity scores, $\\mathcal{S}$, using the calibration data $\\mathcal{D}^{\\text{cali}}$ and the fitted model $\\hat\\mu_{i \\in \\mathcal{D}^{\\text{train}}}$. \n4. For a user-specified desired coverage ratio $(1-\\alpha)$ compute the corresponding quantile, $\\hat{q}$, of the empirical distribution of nonconformity scores, $\\mathcal{S}$.\n5. For the given quantile and test sample $X_{\\text{test}}$, form the corresponding conformal prediction set: \n\n$$\nC(X_{\\text{test}})=\\{y:s(X_{\\text{test}},y) \\le \\hat{q}\\}\n$$ {#eq-set}\n\nThis is the default procedure used for classification and regression in [`ConformalPrediction.jl`](https://github.com/pat-alt/ConformalPrediction.jl). \n\nUsing the package, we can apply Split Conformal Prediction as follows:\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nX = table(permutedims(counterfactual_data.X))\ny = counterfactual_data.output_encoder.labels\nconf_model = conformal_model(clf; method=:simple_inductive)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\nTo be clear, all of the calibration steps (3 to 5) are post hoc, and yet none of them involved any changes to the model parameters. These are two important characteristics of Split Conformal Prediction (SCP) that make it particularly useful in the context of Algorithmic Recourse. Firstly, the fact that SCP involves posthoc calibration steps that happen after training, ensures that we need not place any restrictions on the black-box model itself. This stands in contrast to the approach proposed by @schut2021generating in which they essentially restrict the class of models to Bayesian models. Secondly, the fact that the model itself is kept entirely intact ensures that the generated counterfactuals maintain fidelity to the model. Finally, note that we also have not resorted to a surrogate model to learn more about $X \\sim \\mathcal{X}$. Instead, we have used the fitted model itself and a calibration data set to learn about the model's predictive uncertainty. \n\n## Differentiable CP\n\nIn order to use CP in the context of gradient-based counterfactual search, we need it to be differentiable. @stutz2022learning introduce a framework for training differentiable conformal predictors. They introduce a configurable loss function as well as smooth set size penalty.\n\n### Smooth Set Size Penalty\n\nStarting with the former, @stutz2022learning propose the following:\n\n$$\n\\Omega(C_{\\theta}(x;\\tau)) = = \\max (0, \\sum_k C_{\\theta,k}(x;\\tau) - \\kappa)\n$$ {#eq-size-loss}\n\nHere, $C_{\\theta,k}(x;\\tau)$ is loosely defined as the probability that class $k$ is assigned to the conformal prediction set $C$. In the context of Conformal Training, this penalty reduces the **inefficiency** of the conformal predictor. \n\nIn our context, we are not interested in improving the model itself, but rather in producing **plausible** counterfactuals. Provided that our counterfactual $x^\\prime$ is already inside the target domain ($\\mathbb{I}_{y^\\prime = t}=1$), penalizing $\\Omega(C_{\\theta}(x;\\tau))$ corresponds to guiding counterfactuals into regions of the target domain that are characterized by low ambiguity: for $\\kappa=1$ the conformal prediction set includes only the target label $t$ as $\\Omega(C_{\\theta}(x;\\tau))$. Arguably, less ambiguous counterfactuals are more **plausible**. Since the search is guided purely by properties of the model itself and (exchangeable) calibration data, counterfactuals also maintain **high fidelity**.\n\nThe left panel of @fig-losses shows the smooth size penalty in the two-dimensional feature space of our synthetic data.\n\n### Configurable Classification Loss\n\nThe right panel of @fig-losses shows the configurable classification loss in the two-dimensional feature space of our synthetic data.\n\n::: {.cell execution_count=5}\n\n::: {.cell-output .cell-output-display execution_count=6}\n{#fig-losses}\n:::\n:::\n\n\n## Fidelity and Plausibility\n\nThe main evaluation criteria we are interested in are *fidelity* and *plausibility*. Interestingly, we could also consider using these measures as penalties in the counterfactual search.\n\n### Fidelity\n\nWe propose to define fidelity as follows:\n\n::: {#def-fidelity}\n\n## High-Fidelity Counterfactuals\n\nLet $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\\theta$. Then for $x^{\\prime}$ to be considered a high-fidelity counterfactual, we need: $\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}$ where $t$ denotes the target outcome.\n\n:::\n\nWe can generate samples from $p_{\\theta}(X|y)$ following @grathwohl2020your. In @fig-energy, I have applied the methodology to our synthetic data.\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nM = ECCCE.ConformalModel(conf_model, mach.fitresult)\n\nniter = 100\nnsamples = 100\n\nplts = []\nfor (i,target) ∈ enumerate(counterfactual_data.y_levels)\n sampler = ECCCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100)\n Xgen = rand(sampler, nsamples)\n plt = Plots.plot(M, counterfactual_data; target=target, zoom=-3,cbar=false)\n Plots.scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=i,shape=:star,label=\"X|y=$target\")\n push!(plts, plt)\nend\nPlots.plot(plts..., layout=(1,length(plts)), size=(img_height*length(plts),img_height))\n```\n\n::: {.cell-output .cell-output-display execution_count=7}\n{#fig-energy}\n:::\n:::\n\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual $x^{\\prime}$ from these generated samples, for example.\n\n### Plausibility\n\nWe propose to define plausibility as follows:\n\n::: {#def-plausible}\n\n## Plausible Counterfactuals\n\nFormally, let $\\mathcal{X}|t$ denote the conditional distribution of samples in the target class. As before, we have $x^{\\prime}\\sim\\mathcal{X}^{\\prime}$, then for $x^{\\prime}$ to be considered a plausible counterfactual, we need: $\\mathcal{X}|t \\approxeq \\mathcal{X}^{\\prime}$.\n\n:::\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual $x^{\\prime}$ from (potentially bootstrapped) training samples in the target class, for example.\n\n## Counterfactual Explanations\n\nNext, let's generate counterfactual explanations for our synthetic data. We first wrap our model in a container that makes it compatible with `CounterfactualExplanations.jl`. Then we draw a random sample, determine its predicted label $\\hat{y}$ and choose the opposite label as our target. \n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nx = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))\ny_factual = predict_label(M, counterfactual_data, x)[1]\ntarget = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y_factual][1]\n```\n:::\n\n\nThe generic Conformal Counterfactual Generator penalises the only the set size only:\n\n$$\nx^\\prime = \\arg \\min_{x^\\prime} \\ell(M(x^\\prime),t) + \\lambda \\mathbb{I}_{y^\\prime = t} \\Omega(C_{\\theta}(x;\\tau)) \n$$ {#eq-solution}\n\n::: {.cell execution_count=8}\n\n::: {.cell-output .cell-output-display execution_count=9}\n{#fig-ce}\n:::\n:::\n\n\n## Multi-Class\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\ncounterfactual_data = load_multi_class()\n```\n:::\n\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nX = table(permutedims(counterfactual_data.X))\ny = counterfactual_data.output_encoder.labels\n```\n:::\n\n\n::: {.cell execution_count=11}\n\n::: {.cell-output .cell-output-display execution_count=12}\n{#fig-pen-multi}\n:::\n:::\n\n\n::: {.cell execution_count=12}\n\n::: {.cell-output .cell-output-display execution_count=13}\n{#fig-losses-multi}\n:::\n:::\n\n\n::: {.cell execution_count=13}\n\n::: {.cell-output .cell-output-display execution_count=14}\n{#fig-energy-multi}\n:::\n:::\n\n\n::: {.cell execution_count=14}\n``` {.julia .cell-code}\nx = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))\ny_factual = predict_label(M, counterfactual_data, x)[1]\ntarget = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y_factual][1]\n```\n:::\n\n\n::: {.cell execution_count=15}\n\n::: {.cell-output .cell-output-display execution_count=16}\n{#fig-ce-multi}\n:::\n:::\n\n\n## Benchmarks\n\n::: {.cell execution_count=16}\n``` {.julia .cell-code}\n# Data:\ndatasets = Dict(\n :linearly_separable => load_linearly_separable(),\n :overlapping => load_overlapping(),\n :moons => load_moons(),\n :circles => load_circles(),\n :multi_class => load_multi_class(),\n)\n\n# Untrained Models:\nmodels = Dict(\n :cov75 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.75)),\n :cov80 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.80)),\n :cov90 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.90)),\n :cov99 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.99)),\n)\n```\n:::\n\n\nThen we can simply loop over the datasets and eventually concatenate the results like so:\n\n::: {.cell execution_count=17}\n``` {.julia .cell-code}\nusing CounterfactualExplanations.Evaluation: benchmark\nbmks = []\nmeasures = [\n CounterfactualExplanations.distance,\n ECCCE.distance_from_energy,\n ECCCE.distance_from_targets\n]\nfor (dataname, dataset) in datasets\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=10\n )\n push!(bmks, bmk)\nend\nbmk = reduce(vcat, bmks)\n```\n:::\n\n\n::: {.cell execution_count=18}\n``` {.julia .cell-code}\nf(ce) = CounterfactualExplanations.model_evaluation(ce.M, ce.data)\n@chain bmk() begin\n @group_by(model, generator, dataname, variable)\n @select(model, generator, dataname, ce, value)\n @mutate(performance = f(ce))\n @summarize(model=unique(model), generator=unique(generator), dataname=unique(dataname), performace=unique(performance), value=mean(value))\n @ungroup\n @filter(dataname == :multi_class)\n @filter(model == :cov99)\n @filter(variable == \"distance\")\nend\n```\n:::\n\n\n::: {#fig-benchmark .cell execution_count=19}\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-1}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-2}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-3}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-4}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-5}\n:::\n\nBenchmark results for the different generators.\n:::\n\n\n", + "markdown": "::: {.cell execution_count=1}\n``` {.julia .cell-code}\ninclude(\"notebooks/setup.jl\")\neval(setup_notebooks)\n```\n:::\n\n\n# `ConformalGenerator`\n\nIn this section, we will look at a simple example involving synthetic data, a black-box model and a generic Conformal Counterfactual Generator.\n\n## Black-box Model\n\nWe consider a simple binary classification problem. Let $(X_i, Y_i), \\ i=1,...,n$ denote our feature-label pairs and let $\\mu: \\mathcal{X} \\mapsto \\mathcal{Y}$ denote the mapping from features to labels. For illustration purposes, we will use linearly separable data. \n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\ncounterfactual_data = load_linearly_separable()\n```\n:::\n\n\nWhile we could use a linear classifier in this case, let's pretend we need a black-box model for this task and rely on a small Multi-Layer Perceptron (MLP):\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nbuilder = MLJFlux.@builder Flux.Chain(\n Dense(n_in, 32, relu),\n Dense(32, n_out)\n)\nclf = NeuralNetworkClassifier(builder=builder, epochs=100)\n```\n:::\n\n\nWe can fit this model to data to produce plug-in predictions. \n\n## Conformal Prediction\n\nHere we will instead use a specific case of CP called *split conformal prediction* which can then be summarized as follows:^[In other places split conformal prediction is sometimes referred to as *inductive* conformal prediction.]\n\n1. Partition the training into a proper training set and a separate calibration set: $\\mathcal{D}_n=\\mathcal{D}^{\\text{train}} \\cup \\mathcal{D}^{\\text{cali}}$.\n2. Train the machine learning model on the proper training set: $\\hat\\mu_{i \\in \\mathcal{D}^{\\text{train}}}(X_i,Y_i)$.\n\nThe model $\\hat\\mu_{i \\in \\mathcal{D}^{\\text{train}}}$ can now produce plug-in predictions. \n\n::: callout-note\n\n## Starting Point\n\nNote that this represents the starting point in applications of Algorithmic Recourse: we have some pre-trained classifier $M$ for which we would like to generate plausible Counterfactual Explanations. Next, we turn to the calibration step. \n:::\n\n3. Compute nonconformity scores, $\\mathcal{S}$, using the calibration data $\\mathcal{D}^{\\text{cali}}$ and the fitted model $\\hat\\mu_{i \\in \\mathcal{D}^{\\text{train}}}$. \n4. For a user-specified desired coverage ratio $(1-\\alpha)$ compute the corresponding quantile, $\\hat{q}$, of the empirical distribution of nonconformity scores, $\\mathcal{S}$.\n5. For the given quantile and test sample $X_{\\text{test}}$, form the corresponding conformal prediction set: \n\n$$\nC(X_{\\text{test}})=\\{y:s(X_{\\text{test}},y) \\le \\hat{q}\\}\n$$ {#eq-set}\n\nThis is the default procedure used for classification and regression in [`ConformalPrediction.jl`](https://github.com/pat-alt/ConformalPrediction.jl). \n\nUsing the package, we can apply Split Conformal Prediction as follows:\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nX = table(permutedims(counterfactual_data.X))\ny = counterfactual_data.output_encoder.labels\nconf_model = conformal_model(clf; method=:simple_inductive)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\nTo be clear, all of the calibration steps (3 to 5) are post hoc, and yet none of them involved any changes to the model parameters. These are two important characteristics of Split Conformal Prediction (SCP) that make it particularly useful in the context of Algorithmic Recourse. Firstly, the fact that SCP involves posthoc calibration steps that happen after training, ensures that we need not place any restrictions on the black-box model itself. This stands in contrast to the approach proposed by @schut2021generating in which they essentially restrict the class of models to Bayesian models. Secondly, the fact that the model itself is kept entirely intact ensures that the generated counterfactuals maintain fidelity to the model. Finally, note that we also have not resorted to a surrogate model to learn more about $X \\sim \\mathcal{X}$. Instead, we have used the fitted model itself and a calibration data set to learn about the model's predictive uncertainty. \n\n## Differentiable CP\n\nIn order to use CP in the context of gradient-based counterfactual search, we need it to be differentiable. @stutz2022learning introduce a framework for training differentiable conformal predictors. They introduce a configurable loss function as well as smooth set size penalty.\n\n### Smooth Set Size Penalty\n\nStarting with the former, @stutz2022learning propose the following:\n\n$$\n\\Omega(C_{\\theta}(x;\\tau)) = = \\max (0, \\sum_k C_{\\theta,k}(x;\\tau) - \\kappa)\n$$ {#eq-size-loss}\n\nHere, $C_{\\theta,k}(x;\\tau)$ is loosely defined as the probability that class $k$ is assigned to the conformal prediction set $C$. In the context of Conformal Training, this penalty reduces the **inefficiency** of the conformal predictor. \n\nIn our context, we are not interested in improving the model itself, but rather in producing **plausible** counterfactuals. Provided that our counterfactual $x^\\prime$ is already inside the target domain ($\\mathbb{I}_{y^\\prime = t}=1$), penalizing $\\Omega(C_{\\theta}(x;\\tau))$ corresponds to guiding counterfactuals into regions of the target domain that are characterized by low ambiguity: for $\\kappa=1$ the conformal prediction set includes only the target label $t$ as $\\Omega(C_{\\theta}(x;\\tau))$. Arguably, less ambiguous counterfactuals are more **plausible**. Since the search is guided purely by properties of the model itself and (exchangeable) calibration data, counterfactuals also maintain **high fidelity**.\n\nThe left panel of @fig-losses shows the smooth size penalty in the two-dimensional feature space of our synthetic data.\n\n### Configurable Classification Loss\n\nThe right panel of @fig-losses shows the configurable classification loss in the two-dimensional feature space of our synthetic data.\n\n::: {.cell execution_count=5}\n\n::: {.cell-output .cell-output-display execution_count=6}\n{#fig-losses}\n:::\n:::\n\n\n## Fidelity and Plausibility\n\nThe main evaluation criteria we are interested in are *fidelity* and *plausibility*. Interestingly, we could also consider using these measures as penalties in the counterfactual search.\n\n### Fidelity\n\nWe propose to define fidelity as follows:\n\n::: {#def-fidelity}\n\n## High-Fidelity Counterfactuals\n\nLet $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\\theta$. Then for $x^{\\prime}$ to be considered a high-fidelity counterfactual, we need: $\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}$ where $t$ denotes the target outcome.\n\n:::\n\nWe can generate samples from $p_{\\theta}(X|y)$ following @grathwohl2020your. In @fig-energy, I have applied the methodology to our synthetic data.\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nM = ECCCo.ConformalModel(conf_model, mach.fitresult)\n\nniter = 100\nnsamples = 100\n\nplts = []\nfor (i,target) ∈ enumerate(counterfactual_data.y_levels)\n sampler = ECCCo.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100)\n Xgen = rand(sampler, nsamples)\n plt = Plots.plot(M, counterfactual_data; target=target, zoom=-3,cbar=false)\n Plots.scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=i,shape=:star,label=\"X|y=$target\")\n push!(plts, plt)\nend\nPlots.plot(plts..., layout=(1,length(plts)), size=(img_height*length(plts),img_height))\n```\n\n::: {.cell-output .cell-output-display execution_count=7}\n{#fig-energy}\n:::\n:::\n\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual $x^{\\prime}$ from these generated samples, for example.\n\n### Plausibility\n\nWe propose to define plausibility as follows:\n\n::: {#def-plausible}\n\n## Plausible Counterfactuals\n\nFormally, let $\\mathcal{X}|t$ denote the conditional distribution of samples in the target class. As before, we have $x^{\\prime}\\sim\\mathcal{X}^{\\prime}$, then for $x^{\\prime}$ to be considered a plausible counterfactual, we need: $\\mathcal{X}|t \\approxeq \\mathcal{X}^{\\prime}$.\n\n:::\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual $x^{\\prime}$ from (potentially bootstrapped) training samples in the target class, for example.\n\n## Counterfactual Explanations\n\nNext, let's generate counterfactual explanations for our synthetic data. We first wrap our model in a container that makes it compatible with `CounterfactualExplanations.jl`. Then we draw a random sample, determine its predicted label $\\hat{y}$ and choose the opposite label as our target. \n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nx = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))\ny_factual = predict_label(M, counterfactual_data, x)[1]\ntarget = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y_factual][1]\n```\n:::\n\n\nThe generic Conformal Counterfactual Generator penalises the only the set size only:\n\n$$\nx^\\prime = \\arg \\min_{x^\\prime} \\ell(M(x^\\prime),t) + \\lambda \\mathbb{I}_{y^\\prime = t} \\Omega(C_{\\theta}(x;\\tau)) \n$$ {#eq-solution}\n\n::: {.cell execution_count=8}\n\n::: {.cell-output .cell-output-display execution_count=9}\n{#fig-ce}\n:::\n:::\n\n\n## Multi-Class\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\ncounterfactual_data = load_multi_class()\n```\n:::\n\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nX = table(permutedims(counterfactual_data.X))\ny = counterfactual_data.output_encoder.labels\n```\n:::\n\n\n::: {.cell execution_count=11}\n\n::: {.cell-output .cell-output-display execution_count=12}\n{#fig-pen-multi}\n:::\n:::\n\n\n::: {.cell execution_count=12}\n\n::: {.cell-output .cell-output-display execution_count=13}\n{#fig-losses-multi}\n:::\n:::\n\n\n::: {.cell execution_count=13}\n\n::: {.cell-output .cell-output-display execution_count=14}\n{#fig-energy-multi}\n:::\n:::\n\n\n::: {.cell execution_count=14}\n``` {.julia .cell-code}\nx = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))\ny_factual = predict_label(M, counterfactual_data, x)[1]\ntarget = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y_factual][1]\n```\n:::\n\n\n::: {.cell execution_count=15}\n\n::: {.cell-output .cell-output-display execution_count=16}\n{#fig-ce-multi}\n:::\n:::\n\n\n## Benchmarks\n\n::: {.cell execution_count=16}\n``` {.julia .cell-code}\n# Data:\ndatasets = Dict(\n :linearly_separable => load_linearly_separable(),\n :overlapping => load_overlapping(),\n :moons => load_moons(),\n :circles => load_circles(),\n :multi_class => load_multi_class(),\n)\n\n# Untrained Models:\nmodels = Dict(\n :cov75 => ECCCo.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.75)),\n :cov80 => ECCCo.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.80)),\n :cov90 => ECCCo.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.90)),\n :cov99 => ECCCo.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.99)),\n)\n```\n:::\n\n\nThen we can simply loop over the datasets and eventually concatenate the results like so:\n\n::: {.cell execution_count=17}\n``` {.julia .cell-code}\nusing CounterfactualExplanations.Evaluation: benchmark\nbmks = []\nmeasures = [\n CounterfactualExplanations.distance,\n ECCCo.distance_from_energy,\n ECCCo.distance_from_targets\n]\nfor (dataname, dataset) in datasets\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=10\n )\n push!(bmks, bmk)\nend\nbmk = reduce(vcat, bmks)\n```\n:::\n\n\n::: {.cell execution_count=18}\n``` {.julia .cell-code}\nf(ce) = CounterfactualExplanations.model_evaluation(ce.M, ce.data)\n@chain bmk() begin\n @group_by(model, generator, dataname, variable)\n @select(model, generator, dataname, ce, value)\n @mutate(performance = f(ce))\n @summarize(model=unique(model), generator=unique(generator), dataname=unique(dataname), performace=unique(performance), value=mean(value))\n @ungroup\n @filter(dataname == :multi_class)\n @filter(model == :cov99)\n @filter(variable == \"distance\")\nend\n```\n:::\n\n\n::: {#fig-benchmark .cell execution_count=19}\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-1}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-2}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-3}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-4}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-5}\n:::\n\nBenchmark results for the different generators.\n:::\n\n\n", "supporting": [ "intro_files/figure-html" ], diff --git a/_freeze/notebooks/proposal/execute-results/html.json b/_freeze/notebooks/proposal/execute-results/html.json index e9804de6ee50ab04854e7bc338b7099606456d67..87f9d3999a292180d837eb4ea09d9e0afb77f13f 100644 --- a/_freeze/notebooks/proposal/execute-results/html.json +++ b/_freeze/notebooks/proposal/execute-results/html.json @@ -1,7 +1,7 @@ { "hash": "24ab407f04257b00a84f7dcaee456281", "result": { - "markdown": "---\ntitle: High-Fidelity Counterfactual Explanations through Conformal Prediction\nsubtitle: Research Proposal\nabstract: |\n We propose Conformal Counterfactual Explanations: an effortless and rigorous way to produce realistic and faithful Counterfactual Explanations using Conformal Prediction. To address the need for realistic counterfactuals, existing work has primarily relied on separate generative models to learn the data-generating process. While this is an effective way to produce plausible and model-agnostic counterfactual explanations, it not only introduces a significant engineering overhead but also reallocates the task of creating realistic model explanations from the model itself to the generative model. Recent work has shown that there is no need for any of this when working with probabilistic models that explicitly quantify their own uncertainty. Unfortunately, most models used in practice still do not fulfil that basic requirement, in which case we would like to have a way to quantify predictive uncertainty in a post-hoc fashion.\n---\n\n\n\n## Motivation\n\nCounterfactual Explanations are a powerful, flexible and intuitive way to not only explain black-box models but also enable affected individuals to challenge them through the means of Algorithmic Recourse. \n\n### Counterfactual Explanations or Adversarial Examples?\n\nMost state-of-the-art approaches to generating Counterfactual Explanations (CE) rely on gradient descent in the feature space. The key idea is to perturb inputs $x\\in\\mathcal{X}$ into a black-box model $f: \\mathcal{X} \\mapsto \\mathcal{Y}$ in order to change the model output $f(x)$ to some pre-specified target value $t\\in\\mathcal{Y}$. Formally, this boils down to defining some loss function $\\ell(f(x),t)$ and taking gradient steps in the minimizing direction. The so-generated counterfactuals are considered valid as soon as the predicted label matches the target label. A stripped-down counterfactual explanation is therefore little different from an adversarial example. In @fig-adv, for example, generic counterfactual search as in @wachter2017counterfactual has been applied to MNIST data.\n\n\n\n\n\n\n\n{#fig-adv}\n\nThe crucial difference between adversarial examples and counterfactuals is one of intent. While adversarial examples are typically intended to go unnoticed, counterfactuals in the context of Explainable AI are generally sought to be \"plausible\", \"realistic\" or \"feasible\". To fulfil this latter goal, researchers have come up with a myriad of ways. @joshi2019realistic were among the first to suggest that instead of searching counterfactuals in the feature space, we can instead traverse a latent embedding learned by a surrogate generative model. Similarly, @poyiadzi2020face use density ... Finally, @karimi2021algorithmic argues that counterfactuals should comply with the causal model that generates them [CHECK IF WE CAN PHASE THIS LIKE THIS]. Other related approaches include ... All of these different approaches have a common goal: they aim to ensure that the generated counterfactuals comply with the (learned) data-generating process (DGB). \n\n::: {#def-plausible}\n\n## Plausible Counterfactuals\n\nFormally, if $x \\sim \\mathcal{X}$ and for the corresponding counterfactual we have $x^{\\prime}\\sim\\mathcal{X}^{\\prime}$, then for $x^{\\prime}$ to be considered a plausible counterfactual, we need: $\\mathcal{X} \\approxeq \\mathcal{X}^{\\prime}$.\n\n:::\n\nIn the context of Algorithmic Recourse, it makes sense to strive for plausible counterfactuals, since anything else would essentially require individuals to move to out-of-distribution states. But it is worth noting that our ambition to meet this goal, may have implications on our ability to faithfully explain the behaviour of the underlying black-box model (arguably our principal goal). By essentially decoupling the task of learning plausible representations of the data from the model itself, we open ourselves up to vulnerabilities. Using a separate generative model to learn $\\mathcal{X}$, for example, has very serious implications for the generated counterfactuals. @fig-latent compares the results of applying REVISE [@joshi2019realistic] to MNIST data using two different Variational Auto-Encoders: while the counterfactual generated using an expressive (strong) VAE is compelling, the result relying on a less expressive (weak) VAE is not even valid. In this latter case, the decoder step of the VAE fails to yield values in $\\mathcal{X}$ and hence the counterfactual search in the learned latent space is doomed. \n\n\n\n\n\n\n\n{#fig-latent}\n\n> Here it would be nice to have another example where we poison the data going into the generative model to hide biases present in the data (e.g. Boston housing).\n\n- Latent can be manipulated: \n - train biased model\n - train VAE with biased variable removed/attacked (use Boston housing dataset)\n - hypothesis: will generate bias-free explanations\n\n### From Plausible to High-Fidelity Counterfactuals {#sec-fidelity}\n\nIn light of the findings, we propose to generally avoid using surrogate models to learn $\\mathcal{X}$ in the context of Counterfactual Explanations.\n\n::: {#prp-surrogate}\n\n## Avoid Surrogates\n\nSince we are in the business of explaining a black-box model, the task of learning realistic representations of the data should not be reallocated from the model itself to some surrogate model.\n\n:::\n\nIn cases where the use of surrogate models cannot be avoided, we propose to weigh the plausibility of counterfactuals against their fidelity to the black-box model. In the context of Explainable AI, fidelity is defined as describing how an explanation approximates the prediction of the black-box model [@molnar2020interpretable]. Fidelity has become the default metric for evaluating Local Model-Agnostic Models, since they often involve local surrogate models whose predictions need not always match those of the black-box model. \n\nIn the case of Counterfactual Explanations, the concept of fidelity has so far been ignored. This is not altogether surprising, since by construction and design, Counterfactual Explanations work with the predictions of the black-box model directly: as stated above, a counterfactual $x^{\\prime}$ is considered valid if and only if $f(x^{\\prime})=t$, where $t$ denote some target outcome. \n\nDoes fidelity even make sense in the context of CE, and if so, how can we define it? In light of the examples in the previous section, we think it is urgent to introduce a notion of fidelity in this context, that relates to the distributional properties of the generated counterfactuals. In particular, we propose that a high-fidelity counterfactual $x^{\\prime}$ complies with the class-conditional distribution $\\mathcal{X}_{\\theta} = p_{\\theta}(X|y)$ where $\\theta$ denote the black-box model parameters. \n\n::: {#def-fidele}\n\n## High-Fidelity Counterfactuals\n\nLet $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\\theta$. Then for $x^{\\prime}$ to be considered a high-fidelity counterfactual, we need: $\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}$ where $t$ denotes the target outcome.\n\n:::\n\nIn order to assess the fidelity of counterfactuals, we propose the following two-step procedure:\n\n1) Generate samples $X_{\\theta}|y$ and $X^{\\prime}$ from $\\mathcal{X}_{\\theta}|t$ and $\\mathcal{X}^{\\prime}$, respectively.\n2) Compute the Maximum Mean Discrepancy (MMD) between $X_{\\theta}|y$ and $X^{\\prime}$. \n\nIf the computed value is different from zero, we can reject the null-hypothesis of fidelity.\n\n> Two challenges here: 1) implementing the sampling procedure in @grathwohl2020your; 2) it is unclear if MMD is really the right way to measure this. \n\n## Conformal Counterfactual Explanations\n\nIn @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models. \n\n### Minimizing Predictive Uncertainty\n\n@schut2021generating demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. Similarly, @antoran2020getting ...\n\n- Problem: restricted to Bayesian models.\n- Solution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction. \n\n### Background on Conformal Prediction\n\n- Distribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.\n- Conformal prediction is instance-based. So is CE. \n- Take any fitted model and turn it into a conformal model using calibration data.\n- Our approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway. \n\n> Does the coverage guarantee carry over to counterfactuals?\n\n### Generating Conformal Counterfactuals\n\nWhile Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, @stutz2022learning introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search ...\n\n> Challenge: still need to implement these loss functions. \n\n## Experiments\n\n### Research Questions\n\n- Is CP alone enough to ensure realistic counterfactuals?\n- Do counterfactuals improve further as the models get better?\n- Do counterfactuals get more realistic as coverage\n- What happens as we vary coverage and setsize?\n- What happens as we improve the model robustness?\n- What happens as we improve the model's ability to incorporate predictive uncertainty (deep ensemble, laplace)?\n- What happens if we combine with DiCE, ClaPROAR, Gravitational?\n- What about CE robustness to endogenous shifts [@altmeyer2023endogenous]?\n\n- Benchmarking:\n - add PROBE [@pawelczyk2022probabilistically] into the mix.\n - compare travel costs to domain shits.\n\n> Nice to have: What about using Laplace Approximation, then Conformal Prediction? What about using Conformalised Laplace? \n\n## References\n\n", + "markdown": "---\ntitle: High-Fidelity Counterfactual Explanations through Conformal Prediction\nsubtitle: Research Proposal\nabstract: |\n We propose Conformal Counterfactual Explanations: an effortless and rigorous way to produce realistic and faithful Counterfactual Explanations using Conformal Prediction. To address the need for realistic counterfactuals, existing work has primarily relied on separate generative models to learn the data-generating process. While this is an effective way to produce plausible and model-agnostic counterfactual explanations, it not only introduces a significant engineering overhead but also reallocates the task of creating realistic model explanations from the model itself to the generative model. Recent work has shown that there is no need for any of this when working with probabilistic models that explicitly quantify their own uncertainty. Unfortunately, most models used in practice still do not fulfil that basic requirement, in which case we would like to have a way to quantify predictive uncertainty in a post-hoc fashion.\n---\n\n\n\n## Motivation\n\nCounterfactual Explanations are a powerful, flexible and intuitive way to not only explain black-box models but also enable affected individuals to challenge them through the means of Algorithmic Recourse. \n\n### Counterfactual Explanations or Adversarial Examples?\n\nMost state-of-the-art approaches to generating Counterfactual Explanations (CE) rely on gradient descent in the feature space. The key idea is to perturb inputs $x\\in\\mathcal{X}$ into a black-box model $f: \\mathcal{X} \\mapsto \\mathcal{Y}$ in order to change the model output $f(x)$ to some pre-specified target value $t\\in\\mathcal{Y}$. Formally, this boils down to defining some loss function $\\ell(f(x),t)$ and taking gradient steps in the minimizing direction. The so-generated counterfactuals are considered valid as soon as the predicted label matches the target label. A stripped-down counterfactual explanation is therefore little different from an adversarial example. In @fig-adv, for example, generic counterfactual search as in @wachter2017counterfactual has been applied to MNIST data.\n\n\n\n\n\n\n\n{#fig-adv}\n\nThe crucial difference between adversarial examples and counterfactuals is one of intent. While adversarial examples are typically intended to go unnoticed, counterfactuals in the context of Explainable AI are generally sought to be \"plausible\", \"realistic\" or \"feasible\". To fulfil this latter goal, researchers have come up with a myriad of ways. @joshi2019realistic were among the first to suggest that instead of searching counterfactuals in the feature space, we can instead traverse a latent embedding learned by a surrogate generative model. Similarly, @poyiadzi2020face use density ... Finally, @karimi2021algorithmic argues that counterfactuals should comply with the causal model that generates them [CHECK IF WE CAN PHASE THIS LIKE THIS]. Other related approaches include ... All of these different approaches have a common goal: they aim to ensure that the generated counterfactuals comply with the (learned) data-generating process (DGB). \n\n::: {#def-plausible}\n\n## Plausible Counterfactuals\n\nFormally, if $x \\sim \\mathcal{X}$ and for the corresponding counterfactual we have $x^{\\prime}\\sim\\mathcal{X}^{\\prime}$, then for $x^{\\prime}$ to be considered a plausible counterfactual, we need: $\\mathcal{X} \\approxeq \\mathcal{X}^{\\prime}$.\n\n:::\n\nIn the context of Algorithmic Recourse, it makes sense to strive for plausible counterfactuals, since anything else would essentially require individuals to move to out-of-distribution states. But it is worth noting that our ambition to meet this goal, may have implications on our ability to faithfully explain the behaviour of the underlying black-box model (arguably our principal goal). By essentially decoupling the task of learning plausible representations of the data from the model itself, we open ourselves up to vulnerabilities. Using a separate generative model to learn $\\mathcal{X}$, for example, has very serious implications for the generated counterfactuals. @fig-latent compares the results of applying REVISE [@joshi2019realistic] to MNIST data using two different Variational Auto-Encoders: while the counterfactual generated using an expressive (strong) VAE is compelling, the result relying on a less expressive (weak) VAE is not even valid. In this latter case, the decoder step of the VAE fails to yield values in $\\mathcal{X}$ and hence the counterfactual search in the learned latent space is doomed. \n\n\n\n\n\n\n\n{#fig-latent}\n\n> Here it would be nice to have another example where we poison the data going into the generative model to hide biases present in the data (e.g. Boston housing).\n\n- Latent can be manipulated: \n - train biased model\n - train VAE with biased variable removed/attacked (use Boston housing dataset)\n - hypothesis: will generate bias-free explanations\n\n### From Plausible to High-Fidelity Counterfactuals {#sec-fidelity}\n\nIn light of the findings, we propose to generally avoid using surrogate models to learn $\\mathcal{X}$ in the context of Counterfactual Explanations.\n\n::: {#prp-surrogate}\n\n## Avoid Surrogates\n\nSince we are in the business of explaining a black-box model, the task of learning realistic representations of the data should not be reallocated from the model itself to some surrogate model.\n\n:::\n\nIn cases where the use of surrogate models cannot be avoided, we propose to weigh the plausibility of counterfactuals against their fidelity to the black-box model. In the context of Explainable AI, fidelity is defined as describing how an explanation approximates the prediction of the black-box model [@molnar2020interpretable]. Fidelity has become the default metric for evaluating Local Model-Agnostic Models, since they often involve local surrogate models whose predictions need not always match those of the black-box model. \n\nIn the case of Counterfactual Explanations, the concept of fidelity has so far been ignored. This is not altogether surprising, since by construction and design, Counterfactual Explanations work with the predictions of the black-box model directly: as stated above, a counterfactual $x^{\\prime}$ is considered valid if and only if $f(x^{\\prime})=t$, where $t$ denote some target outcome. \n\nDoes fidelity even make sense in the context of CE, and if so, how can we define it? In light of the examples in the previous section, we think it is urgent to introduce a notion of fidelity in this context, that relates to the distributional properties of the generated counterfactuals. In particular, we propose that a high-fidelity counterfactual $x^{\\prime}$ complies with the class-conditional distribution $\\mathcal{X}_{\\theta} = p_{\\theta}(X|y)$ where $\\theta$ denote the black-box model parameters. \n\n::: {#def-fidele}\n\n## High-Fidelity Counterfactuals\n\nLet $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\\theta$. Then for $x^{\\prime}$ to be considered a high-fidelity counterfactual, we need: $\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}$ where $t$ denotes the target outcome.\n\n:::\n\nIn order to assess the fidelity of counterfactuals, we propose the following two-step procedure:\n\n1) Generate samples $X_{\\theta}|y$ and $X^{\\prime}$ from $\\mathcal{X}_{\\theta}|t$ and $\\mathcal{X}^{\\prime}$, respectively.\n2) Compute the Maximum Mean Discrepancy (MMD) between $X_{\\theta}|y$ and $X^{\\prime}$. \n\nIf the computed value is different from zero, we can reject the null-hypothesis of fidelity.\n\n> Two challenges here: 1) implementing the sampling procedure in @grathwohl2020your; 2) it is unclear if MMD is really the right way to measure this. \n\n## Conformal Counterfactual Explanations\n\nIn @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCo), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models. \n\n### Minimizing Predictive Uncertainty\n\n@schut2021generating demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. Similarly, @antoran2020getting ...\n\n- Problem: restricted to Bayesian models.\n- Solution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction. \n\n### Background on Conformal Prediction\n\n- Distribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.\n- Conformal prediction is instance-based. So is CE. \n- Take any fitted model and turn it into a conformal model using calibration data.\n- Our approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway. \n\n> Does the coverage guarantee carry over to counterfactuals?\n\n### Generating Conformal Counterfactuals\n\nWhile Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, @stutz2022learning introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search ...\n\n> Challenge: still need to implement these loss functions. \n\n## Experiments\n\n### Research Questions\n\n- Is CP alone enough to ensure realistic counterfactuals?\n- Do counterfactuals improve further as the models get better?\n- Do counterfactuals get more realistic as coverage\n- What happens as we vary coverage and setsize?\n- What happens as we improve the model robustness?\n- What happens as we improve the model's ability to incorporate predictive uncertainty (deep ensemble, laplace)?\n- What happens if we combine with DiCE, ClaPROAR, Gravitational?\n- What about CE robustness to endogenous shifts [@altmeyer2023endogenous]?\n\n- Benchmarking:\n - add PROBE [@pawelczyk2022probabilistically] into the mix.\n - compare travel costs to domain shits.\n\n> Nice to have: What about using Laplace Approximation, then Conformal Prediction? What about using Conformalised Laplace? \n\n## References\n\n", "supporting": [ "proposal_files/figure-html" ], diff --git a/_freeze/notebooks/synthetic/execute-results/html.json b/_freeze/notebooks/synthetic/execute-results/html.json index 682d0f96f0cf96570eecbd0646392516bcc73e08..273069ba8f8578fbbc11c235c4ee7285dd0f71a4 100644 --- a/_freeze/notebooks/synthetic/execute-results/html.json +++ b/_freeze/notebooks/synthetic/execute-results/html.json @@ -1,7 +1,7 @@ { "hash": "617bb13e20ec081d43c585fd80675156", "result": { - "markdown": "::: {.cell execution_count=1}\n``` {.julia .cell-code}\ninclude(\"notebooks/setup.jl\")\neval(setup_notebooks);\n```\n:::\n\n\n# Synthetic data\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n# Data:\ndatasets = Dict(\n :linearly_separable => load_linearly_separable(),\n :overlapping => load_overlapping(),\n :moons => load_moons(),\n :circles => load_circles(),\n :multi_class => load_multi_class(),\n)\n\n# Hyperparameters:\ncvgs = [0.5, 0.75, 0.95]\ntemps = [0.01, 0.1, 1.0]\nΛ = [0.0, 0.1, 1.0, 10.0]\nl2_λ = 0.1\n\n# Classifiers:\nepochs = 250\nlink_fun = relu\nlogreg = NeuralNetworkClassifier(builder=MLJFlux.Linear(σ=link_fun), epochs=epochs)\nmlp = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=link_fun), epochs=epochs)\nensmbl = EnsembleModel(model=mlp, n=5)\nclassifiers = Dict(\n # :logreg => logreg,\n :mlp => mlp,\n # :ensmbl => ensmbl,\n)\n\n# Search parameters:\ntarget = 2\nfactual = 1\nmax_iter = 50\ngradient_tol = 1e-2\nopt = Descent(0.01)\n```\n:::\n\n\n\n\n\n\n::: {.cell execution_count=5}\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n:::\n\n\n## Benchmark\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\n# Benchmark generators:\ngenerators = Dict(\n :wachter => GenericGenerator(opt=opt, λ=l2_λ),\n :revise => REVISEGenerator(opt=opt, λ=l2_λ),\n :greedy => GreedyGenerator(),\n)\n\n# Untrained Models:\nmodels = Dict(Symbol(\"cov$(Int(100*cov))\") => ECCCE.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs)\n\n# Measures:\nmeasures = [\n CounterfactualExplanations.distance,\n ECCCE.distance_from_energy,\n ECCCE.distance_from_targets,\n CounterfactualExplanations.validity,\n]\n```\n:::\n\n\n### Single CE\n\n\n\n\n\n::: {.cell execution_count=9}\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n:::\n\n\n### Full Benchmark\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nbmks = []\nfor (dataname, dataset) in datasets\n for λ in Λ, temp in temps\n _generators = deepcopy(generators)\n _generators[:cce] = ECCCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt)\n _generators[:energy] = ECCCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n _generators[:target] = ECCCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=_generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=5,\n initialization=:identity,\n )\n bmk.evaluation.λ .= λ\n bmk.evaluation.temperature .= temp\n push!(bmks, bmk)\n end\nend\nbmk = reduce(vcat, bmks)\n```\n:::\n\n\n::: {.cell execution_count=11}\n``` {.julia .cell-code}\nCSV.write(joinpath(output_path, \"synthetic_benchmark.csv\"), bmk())\n```\n:::\n\n\n::: {.cell execution_count=12}\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n:::\n\n\n", + "markdown": "::: {.cell execution_count=1}\n``` {.julia .cell-code}\ninclude(\"notebooks/setup.jl\")\neval(setup_notebooks);\n```\n:::\n\n\n# Synthetic data\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n# Data:\ndatasets = Dict(\n :linearly_separable => load_linearly_separable(),\n :overlapping => load_overlapping(),\n :moons => load_moons(),\n :circles => load_circles(),\n :multi_class => load_multi_class(),\n)\n\n# Hyperparameters:\ncvgs = [0.5, 0.75, 0.95]\ntemps = [0.01, 0.1, 1.0]\nΛ = [0.0, 0.1, 1.0, 10.0]\nl2_λ = 0.1\n\n# Classifiers:\nepochs = 250\nlink_fun = relu\nlogreg = NeuralNetworkClassifier(builder=MLJFlux.Linear(σ=link_fun), epochs=epochs)\nmlp = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=link_fun), epochs=epochs)\nensmbl = EnsembleModel(model=mlp, n=5)\nclassifiers = Dict(\n # :logreg => logreg,\n :mlp => mlp,\n # :ensmbl => ensmbl,\n)\n\n# Search parameters:\ntarget = 2\nfactual = 1\nmax_iter = 50\ngradient_tol = 1e-2\nopt = Descent(0.01)\n```\n:::\n\n\n\n\n\n\n::: {.cell execution_count=5}\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n:::\n\n\n## Benchmark\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\n# Benchmark generators:\ngenerators = Dict(\n :wachter => GenericGenerator(opt=opt, λ=l2_λ),\n :revise => REVISEGenerator(opt=opt, λ=l2_λ),\n :greedy => GreedyGenerator(),\n)\n\n# Untrained Models:\nmodels = Dict(Symbol(\"cov$(Int(100*cov))\") => ECCCo.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs)\n\n# Measures:\nmeasures = [\n CounterfactualExplanations.distance,\n ECCCo.distance_from_energy,\n ECCCo.distance_from_targets,\n CounterfactualExplanations.validity,\n]\n```\n:::\n\n\n### Single CE\n\n\n\n\n\n::: {.cell execution_count=9}\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n:::\n\n\n### Full Benchmark\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nbmks = []\nfor (dataname, dataset) in datasets\n for λ in Λ, temp in temps\n _generators = deepcopy(generators)\n _generators[:cce] = ECCCoGenerator(temp=temp, λ=[l2_λ,λ], opt=opt)\n _generators[:energy] = ECCCo.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n _generators[:target] = ECCCo.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=_generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=5,\n initialization=:identity,\n )\n bmk.evaluation.λ .= λ\n bmk.evaluation.temperature .= temp\n push!(bmks, bmk)\n end\nend\nbmk = reduce(vcat, bmks)\n```\n:::\n\n\n::: {.cell execution_count=11}\n``` {.julia .cell-code}\nCSV.write(joinpath(output_path, \"synthetic_benchmark.csv\"), bmk())\n```\n:::\n\n\n::: {.cell execution_count=12}\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n:::\n\n\n", "supporting": [ "synthetic_files" ], diff --git a/docs/notebooks/intro.html b/docs/notebooks/intro.html index 6a00699b32a2ebbece143893998f4e1ecb2e9aa5..99c0a5a3d902bca1966c59c0756568516564ad48 100644 --- a/docs/notebooks/intro.html +++ b/docs/notebooks/intro.html @@ -351,14 +351,14 @@ C(X_{\text{test}})=\{y:s(X_{\text{test}},y) \le \hat{q}\} </div> <p>We can generate samples from <span class="math inline">\(p_{\theta}(X|y)\)</span> following <span class="citation" data-cites="grathwohl2020your">Grathwohl et al. (<a href="references.html#ref-grathwohl2020your" role="doc-biblioref">2020</a>)</span>. In <a href="#fig-energy">Figure <span>2.2</span></a>, I have applied the methodology to our synthetic data.</p> <div class="cell" data-execution_count="6"> -<div class="sourceCode cell-code" id="cb5"><pre class="sourceCode julia code-with-copy"><code class="sourceCode julia"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a>M <span class="op">=</span> ECCCE.<span class="fu">ConformalModel</span>(conf_model, mach.fitresult)</span> +<div class="sourceCode cell-code" id="cb5"><pre class="sourceCode julia code-with-copy"><code class="sourceCode julia"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a>M <span class="op">=</span> ECCCo.<span class="fu">ConformalModel</span>(conf_model, mach.fitresult)</span> <span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a></span> <span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>niter <span class="op">=</span> <span class="fl">100</span></span> <span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>nsamples <span class="op">=</span> <span class="fl">100</span></span> <span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a></span> <span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>plts <span class="op">=</span> []</span> <span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> (i,target) <span class="op">∈</span> <span class="fu">enumerate</span>(counterfactual_data.y_levels)</span> -<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a> sampler <span class="op">=</span> ECCCE.<span class="fu">EnergySampler</span>(M, counterfactual_data, target; niter<span class="op">=</span>niter, nsamples<span class="op">=</span><span class="fl">100</span>)</span> +<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a> sampler <span class="op">=</span> ECCCo.<span class="fu">EnergySampler</span>(M, counterfactual_data, target; niter<span class="op">=</span>niter, nsamples<span class="op">=</span><span class="fl">100</span>)</span> <span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a> Xgen <span class="op">=</span> <span class="fu">rand</span>(sampler, nsamples)</span> <span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a> plt <span class="op">=</span> Plots.<span class="fu">plot</span>(M, counterfactual_data; target<span class="op">=</span>target, zoom<span class="op">=-</span><span class="fl">3</span>,cbar<span class="op">=</span><span class="cn">false</span>)</span> <span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a> Plots.<span class="fu">scatter!</span>(Xgen[<span class="fl">1</span>,<span class="op">:</span>],Xgen[<span class="fl">2</span>,<span class="op">:</span>],alpha<span class="op">=</span><span class="fl">0.5</span>,color<span class="op">=</span>i,shape<span class="op">=:</span>star,label<span class="op">=</span><span class="st">"X|y=</span><span class="sc">$</span>target<span class="st">"</span>)</span> @@ -477,10 +477,10 @@ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) + \lambda \mathbb{I}_{y^\pr <span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a></span> <span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Untrained Models:</span></span> <span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>models <span class="op">=</span> <span class="fu">Dict</span>(</span> -<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov75 <span class="op">=></span> ECCCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.75</span>)),</span> -<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov80 <span class="op">=></span> ECCCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.80</span>)),</span> -<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov90 <span class="op">=></span> ECCCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.90</span>)),</span> -<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov99 <span class="op">=></span> ECCCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.99</span>)),</span> +<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov75 <span class="op">=></span> ECCCo.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.75</span>)),</span> +<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov80 <span class="op">=></span> ECCCo.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.80</span>)),</span> +<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov90 <span class="op">=></span> ECCCo.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.90</span>)),</span> +<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov99 <span class="op">=></span> ECCCo.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.99</span>)),</span> <span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>)</span></code><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></pre></div> </div> <p>Then we can simply loop over the datasets and eventually concatenate the results like so:</p> @@ -489,8 +489,8 @@ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) + \lambda \mathbb{I}_{y^\pr <span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>bmks <span class="op">=</span> []</span> <span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>measures <span class="op">=</span> [</span> <span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a> CounterfactualExplanations.distance,</span> -<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a> ECCCE.distance_from_energy,</span> -<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a> ECCCE.distance_from_targets</span> +<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a> ECCCo.distance_from_energy,</span> +<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a> ECCCo.distance_from_targets</span> <span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>]</span> <span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> (dataname, dataset) <span class="kw">in</span> datasets</span> <span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a> bmk <span class="op">=</span> <span class="fu">benchmark</span>(</span> diff --git a/docs/notebooks/proposal.html b/docs/notebooks/proposal.html index 4d65db90bd6789a0b1962f333355b8d6838870cc..d527e1a4097062ced8c1c7720a130fd05caf0174 100644 --- a/docs/notebooks/proposal.html +++ b/docs/notebooks/proposal.html @@ -253,7 +253,7 @@ div.csl-indent { </section> <section id="conformal-counterfactual-explanations" class="level2" data-number="1.2"> <h2 data-number="1.2" class="anchored" data-anchor-id="conformal-counterfactual-explanations"><span class="header-section-number">1.2</span> Conformal Counterfactual Explanations</h2> -<p>In <a href="#sec-fidelity"><span>Section 1.1.2</span></a>, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models.</p> +<p>In <a href="#sec-fidelity"><span>Section 1.1.2</span></a>, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCo), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models.</p> <section id="minimizing-predictive-uncertainty" class="level3" data-number="1.2.1"> <h3 data-number="1.2.1" class="anchored" data-anchor-id="minimizing-predictive-uncertainty"><span class="header-section-number">1.2.1</span> Minimizing Predictive Uncertainty</h3> <p><span class="citation" data-cites="schut2021generating">Schut et al. (<a href="references.html#ref-schut2021generating" role="doc-biblioref">2021</a>)</span> demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. Similarly, <span class="citation" data-cites="antoran2020getting">Antorán et al. (<a href="references.html#ref-antoran2020getting" role="doc-biblioref">2020</a>)</span> …</p> diff --git a/docs/notebooks/synthetic.html b/docs/notebooks/synthetic.html index ec5ffcc29748f3be384d57a126c744d09b357a8b..c058f1e5717efe64720a59ac070d086284a8b641 100644 --- a/docs/notebooks/synthetic.html +++ b/docs/notebooks/synthetic.html @@ -330,13 +330,13 @@ code span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } /* Warni <span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>)</span> <span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a></span> <span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Untrained Models:</span></span> -<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>models <span class="op">=</span> <span class="fu">Dict</span>(<span class="fu">Symbol</span>(<span class="st">"cov</span><span class="sc">$</span>(<span class="fu">Int</span>(<span class="fl">100</span><span class="op">*</span>cov))<span class="st">"</span>) <span class="op">=></span> ECCCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(mlp; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span>cov)) <span class="cf">for</span> cov <span class="kw">in</span> cvgs)</span> +<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>models <span class="op">=</span> <span class="fu">Dict</span>(<span class="fu">Symbol</span>(<span class="st">"cov</span><span class="sc">$</span>(<span class="fu">Int</span>(<span class="fl">100</span><span class="op">*</span>cov))<span class="st">"</span>) <span class="op">=></span> ECCCo.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(mlp; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span>cov)) <span class="cf">for</span> cov <span class="kw">in</span> cvgs)</span> <span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a></span> <span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Measures:</span></span> <span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>measures <span class="op">=</span> [</span> <span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a> CounterfactualExplanations.distance,</span> -<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a> ECCCE.distance_from_energy,</span> -<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a> ECCCE.distance_from_targets,</span> +<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a> ECCCo.distance_from_energy,</span> +<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a> ECCCo.distance_from_targets,</span> <span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a> CounterfactualExplanations.validity,</span> <span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>]</span></code><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></pre></div> </div> @@ -397,9 +397,9 @@ code span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } /* Warni <span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> (dataname, dataset) <span class="kw">in</span> datasets</span> <span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> λ <span class="kw">in</span> Λ, temp <span class="kw">in</span> temps</span> <span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a> _generators <span class="op">=</span> <span class="fu">deepcopy</span>(generators)</span> -<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a> _generators[<span class="op">:</span>cce] <span class="op">=</span> <span class="fu">ECCCEGenerator</span>(temp<span class="op">=</span>temp, λ<span class="op">=</span>[l2_λ,λ], opt<span class="op">=</span>opt)</span> -<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a> _generators[<span class="op">:</span>energy] <span class="op">=</span> ECCCE.<span class="fu">EnergyDrivenGenerator</span>(λ<span class="op">=</span>[l2_λ,λ], opt<span class="op">=</span>opt)</span> -<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a> _generators[<span class="op">:</span>target] <span class="op">=</span> ECCCE.<span class="fu">TargetDrivenGenerator</span>(λ<span class="op">=</span>[l2_λ,λ], opt<span class="op">=</span>opt)</span> +<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a> _generators[<span class="op">:</span>cce] <span class="op">=</span> <span class="fu">ECCCoGenerator</span>(temp<span class="op">=</span>temp, λ<span class="op">=</span>[l2_λ,λ], opt<span class="op">=</span>opt)</span> +<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a> _generators[<span class="op">:</span>energy] <span class="op">=</span> ECCCo.<span class="fu">EnergyDrivenGenerator</span>(λ<span class="op">=</span>[l2_λ,λ], opt<span class="op">=</span>opt)</span> +<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a> _generators[<span class="op">:</span>target] <span class="op">=</span> ECCCo.<span class="fu">TargetDrivenGenerator</span>(λ<span class="op">=</span>[l2_λ,λ], opt<span class="op">=</span>opt)</span> <span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a> bmk <span class="op">=</span> <span class="fu">benchmark</span>(</span> <span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a> dataset; </span> <span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a> models<span class="op">=</span><span class="fu">deepcopy</span>(models), </span> diff --git a/docs/search.json b/docs/search.json index c097c74a608f9bbb42131f8a9e2c671c696bc0cc..bf32aeaf8e1518e44764f68e3341777feacf83f3 100644 --- a/docs/search.json +++ b/docs/search.json @@ -18,7 +18,7 @@ "href": "notebooks/proposal.html#conformal-counterfactual-explanations", "title": "1 High-Fidelity Counterfactual Explanations through Conformal Prediction", "section": "1.2 Conformal Counterfactual Explanations", - "text": "1.2 Conformal Counterfactual Explanations\nIn Section 1.1.2, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models.\n\n1.2.1 Minimizing Predictive Uncertainty\nSchut et al. (2021) demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. Similarly, Antorán et al. (2020) …\n\nProblem: restricted to Bayesian models.\nSolution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction.\n\n\n\n1.2.2 Background on Conformal Prediction\n\nDistribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.\nConformal prediction is instance-based. So is CE.\nTake any fitted model and turn it into a conformal model using calibration data.\nOur approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway.\n\n\nDoes the coverage guarantee carry over to counterfactuals?\n\n\n\n1.2.3 Generating Conformal Counterfactuals\nWhile Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, Stutz et al. (2022) introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search …\n\nChallenge: still need to implement these loss functions." + "text": "1.2 Conformal Counterfactual Explanations\nIn Section 1.1.2, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCo), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models.\n\n1.2.1 Minimizing Predictive Uncertainty\nSchut et al. (2021) demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. Similarly, Antorán et al. (2020) …\n\nProblem: restricted to Bayesian models.\nSolution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction.\n\n\n\n1.2.2 Background on Conformal Prediction\n\nDistribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.\nConformal prediction is instance-based. So is CE.\nTake any fitted model and turn it into a conformal model using calibration data.\nOur approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway.\n\n\nDoes the coverage guarantee carry over to counterfactuals?\n\n\n\n1.2.3 Generating Conformal Counterfactuals\nWhile Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, Stutz et al. (2022) introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search …\n\nChallenge: still need to implement these loss functions." }, { "objectID": "notebooks/proposal.html#experiments", @@ -60,7 +60,7 @@ "href": "notebooks/intro.html#fidelity-and-plausibility", "title": "2 ConformalGenerator", "section": "2.4 Fidelity and Plausibility", - "text": "2.4 Fidelity and Plausibility\nThe main evaluation criteria we are interested in are fidelity and plausibility. Interestingly, we could also consider using these measures as penalties in the counterfactual search.\n\n2.4.1 Fidelity\nWe propose to define fidelity as follows:\n\nDefinition 2.1 (High-Fidelity Counterfactuals) Let \\(\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)\\) denote the class-conditional distribution of \\(X\\) defined by \\(\\theta\\). Then for \\(x^{\\prime}\\) to be considered a high-fidelity counterfactual, we need: \\(\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}\\) where \\(t\\) denotes the target outcome.\n\nWe can generate samples from \\(p_{\\theta}(X|y)\\) following Grathwohl et al. (2020). In Figure 2.2, I have applied the methodology to our synthetic data.\n\nM = ECCCE.ConformalModel(conf_model, mach.fitresult)\n\nniter = 100\nnsamples = 100\n\nplts = []\nfor (i,target) ∈ enumerate(counterfactual_data.y_levels)\n sampler = ECCCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100)\n Xgen = rand(sampler, nsamples)\n plt = Plots.plot(M, counterfactual_data; target=target, zoom=-3,cbar=false)\n Plots.scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=i,shape=:star,label=\"X|y=$target\")\n push!(plts, plt)\nend\nPlots.plot(plts..., layout=(1,length(plts)), size=(img_height*length(plts),img_height))\n\n\n\n\nFigure 2.2: Energy-based conditional samples.\n\n\n\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual \\(x^{\\prime}\\) from these generated samples, for example.\n\n\n2.4.2 Plausibility\nWe propose to define plausibility as follows:\n\nDefinition 2.2 (Plausible Counterfactuals) Formally, let \\(\\mathcal{X}|t\\) denote the conditional distribution of samples in the target class. As before, we have \\(x^{\\prime}\\sim\\mathcal{X}^{\\prime}\\), then for \\(x^{\\prime}\\) to be considered a plausible counterfactual, we need: \\(\\mathcal{X}|t \\approxeq \\mathcal{X}^{\\prime}\\).\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual \\(x^{\\prime}\\) from (potentially bootstrapped) training samples in the target class, for example." + "text": "2.4 Fidelity and Plausibility\nThe main evaluation criteria we are interested in are fidelity and plausibility. Interestingly, we could also consider using these measures as penalties in the counterfactual search.\n\n2.4.1 Fidelity\nWe propose to define fidelity as follows:\n\nDefinition 2.1 (High-Fidelity Counterfactuals) Let \\(\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)\\) denote the class-conditional distribution of \\(X\\) defined by \\(\\theta\\). Then for \\(x^{\\prime}\\) to be considered a high-fidelity counterfactual, we need: \\(\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}\\) where \\(t\\) denotes the target outcome.\n\nWe can generate samples from \\(p_{\\theta}(X|y)\\) following Grathwohl et al. (2020). In Figure 2.2, I have applied the methodology to our synthetic data.\n\nM = ECCCo.ConformalModel(conf_model, mach.fitresult)\n\nniter = 100\nnsamples = 100\n\nplts = []\nfor (i,target) ∈ enumerate(counterfactual_data.y_levels)\n sampler = ECCCo.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100)\n Xgen = rand(sampler, nsamples)\n plt = Plots.plot(M, counterfactual_data; target=target, zoom=-3,cbar=false)\n Plots.scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=i,shape=:star,label=\"X|y=$target\")\n push!(plts, plt)\nend\nPlots.plot(plts..., layout=(1,length(plts)), size=(img_height*length(plts),img_height))\n\n\n\n\nFigure 2.2: Energy-based conditional samples.\n\n\n\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual \\(x^{\\prime}\\) from these generated samples, for example.\n\n\n2.4.2 Plausibility\nWe propose to define plausibility as follows:\n\nDefinition 2.2 (Plausible Counterfactuals) Formally, let \\(\\mathcal{X}|t\\) denote the conditional distribution of samples in the target class. As before, we have \\(x^{\\prime}\\sim\\mathcal{X}^{\\prime}\\), then for \\(x^{\\prime}\\) to be considered a plausible counterfactual, we need: \\(\\mathcal{X}|t \\approxeq \\mathcal{X}^{\\prime}\\).\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual \\(x^{\\prime}\\) from (potentially bootstrapped) training samples in the target class, for example." }, { "objectID": "notebooks/intro.html#counterfactual-explanations", @@ -81,14 +81,14 @@ "href": "notebooks/intro.html#benchmarks", "title": "2 ConformalGenerator", "section": "2.7 Benchmarks", - "text": "2.7 Benchmarks\n\n# Data:\ndatasets = Dict(\n :linearly_separable => load_linearly_separable(),\n :overlapping => load_overlapping(),\n :moons => load_moons(),\n :circles => load_circles(),\n :multi_class => load_multi_class(),\n)\n\n# Untrained Models:\nmodels = Dict(\n :cov75 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.75)),\n :cov80 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.80)),\n :cov90 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.90)),\n :cov99 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.99)),\n)\n\nThen we can simply loop over the datasets and eventually concatenate the results like so:\n\nusing CounterfactualExplanations.Evaluation: benchmark\nbmks = []\nmeasures = [\n CounterfactualExplanations.distance,\n ECCCE.distance_from_energy,\n ECCCE.distance_from_targets\n]\nfor (dataname, dataset) in datasets\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=10\n )\n push!(bmks, bmk)\nend\nbmk = reduce(vcat, bmks)\n\n\nf(ce) = CounterfactualExplanations.model_evaluation(ce.M, ce.data)\n@chain bmk() begin\n @group_by(model, generator, dataname, variable)\n @select(model, generator, dataname, ce, value)\n @mutate(performance = f(ce))\n @summarize(model=unique(model), generator=unique(generator), dataname=unique(dataname), performace=unique(performance), value=mean(value))\n @ungroup\n @filter(dataname == :multi_class)\n @filter(model == :cov99)\n @filter(variable == \"distance\")\nend\n\n\n\n\n\n\n\n(a) Circles.\n\n\n\n\n\n\n\n(b) Linearly Separable.\n\n\n\n\n\n\n\n(c) Moons.\n\n\n\n\n\n\n\n(d) Multi-class.\n\n\n\n\n\n\n\n(e) Overlapping.\n\n\n\nFigure 2.8: Benchmark results for the different generators.\n\n\n\n\n\n\nGrathwohl, Will, Kuan-Chieh Wang, Joern-Henrik Jacobsen, David Duvenaud, Mohammad Norouzi, and Kevin Swersky. 2020. “Your Classifier Is Secretly an Energy Based Model and You Should Treat It Like One.†In. https://openreview.net/forum?id=Hkxzx0NtDB.\n\n\nSchut, Lisa, Oscar Key, Rory Mc Grath, Luca Costabello, Bogdan Sacaleanu, Yarin Gal, et al. 2021. “Generating Interpretable Counterfactual Explanations By Implicit Minimisation of Epistemic and Aleatoric Uncertainties.†In International Conference on Artificial Intelligence and Statistics, 1756–64. PMLR.\n\n\nStutz, David, Krishnamurthy Dj Dvijotham, Ali Taylan Cemgil, and Arnaud Doucet. 2022. “Learning Optimal Conformal Classifiers.†In. https://openreview.net/forum?id=t8O-4LKFVx." + "text": "2.7 Benchmarks\n\n# Data:\ndatasets = Dict(\n :linearly_separable => load_linearly_separable(),\n :overlapping => load_overlapping(),\n :moons => load_moons(),\n :circles => load_circles(),\n :multi_class => load_multi_class(),\n)\n\n# Untrained Models:\nmodels = Dict(\n :cov75 => ECCCo.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.75)),\n :cov80 => ECCCo.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.80)),\n :cov90 => ECCCo.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.90)),\n :cov99 => ECCCo.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.99)),\n)\n\nThen we can simply loop over the datasets and eventually concatenate the results like so:\n\nusing CounterfactualExplanations.Evaluation: benchmark\nbmks = []\nmeasures = [\n CounterfactualExplanations.distance,\n ECCCo.distance_from_energy,\n ECCCo.distance_from_targets\n]\nfor (dataname, dataset) in datasets\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=10\n )\n push!(bmks, bmk)\nend\nbmk = reduce(vcat, bmks)\n\n\nf(ce) = CounterfactualExplanations.model_evaluation(ce.M, ce.data)\n@chain bmk() begin\n @group_by(model, generator, dataname, variable)\n @select(model, generator, dataname, ce, value)\n @mutate(performance = f(ce))\n @summarize(model=unique(model), generator=unique(generator), dataname=unique(dataname), performace=unique(performance), value=mean(value))\n @ungroup\n @filter(dataname == :multi_class)\n @filter(model == :cov99)\n @filter(variable == \"distance\")\nend\n\n\n\n\n\n\n\n(a) Circles.\n\n\n\n\n\n\n\n(b) Linearly Separable.\n\n\n\n\n\n\n\n(c) Moons.\n\n\n\n\n\n\n\n(d) Multi-class.\n\n\n\n\n\n\n\n(e) Overlapping.\n\n\n\nFigure 2.8: Benchmark results for the different generators.\n\n\n\n\n\n\nGrathwohl, Will, Kuan-Chieh Wang, Joern-Henrik Jacobsen, David Duvenaud, Mohammad Norouzi, and Kevin Swersky. 2020. “Your Classifier Is Secretly an Energy Based Model and You Should Treat It Like One.†In. https://openreview.net/forum?id=Hkxzx0NtDB.\n\n\nSchut, Lisa, Oscar Key, Rory Mc Grath, Luca Costabello, Bogdan Sacaleanu, Yarin Gal, et al. 2021. “Generating Interpretable Counterfactual Explanations By Implicit Minimisation of Epistemic and Aleatoric Uncertainties.†In International Conference on Artificial Intelligence and Statistics, 1756–64. PMLR.\n\n\nStutz, David, Krishnamurthy Dj Dvijotham, Ali Taylan Cemgil, and Arnaud Doucet. 2022. “Learning Optimal Conformal Classifiers.†In. https://openreview.net/forum?id=t8O-4LKFVx." }, { "objectID": "notebooks/synthetic.html#benchmark", "href": "notebooks/synthetic.html#benchmark", "title": "3 Synthetic data", "section": "3.1 Benchmark", - "text": "3.1 Benchmark\n\n# Benchmark generators:\ngenerators = Dict(\n :wachter => GenericGenerator(opt=opt, λ=l2_λ),\n :revise => REVISEGenerator(opt=opt, λ=l2_λ),\n :greedy => GreedyGenerator(),\n)\n\n# Untrained Models:\nmodels = Dict(Symbol(\"cov$(Int(100*cov))\") => ECCCE.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs)\n\n# Measures:\nmeasures = [\n CounterfactualExplanations.distance,\n ECCCE.distance_from_energy,\n ECCCE.distance_from_targets,\n CounterfactualExplanations.validity,\n]\n\n\n3.1.1 Single CE\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n3.1.2 Full Benchmark\n\nbmks = []\nfor (dataname, dataset) in datasets\n for λ in Λ, temp in temps\n _generators = deepcopy(generators)\n _generators[:cce] = ECCCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt)\n _generators[:energy] = ECCCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n _generators[:target] = ECCCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=_generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=5,\n initialization=:identity,\n )\n bmk.evaluation.λ .= λ\n bmk.evaluation.temperature .= temp\n push!(bmks, bmk)\n end\nend\nbmk = reduce(vcat, bmks)\n\n\nCSV.write(joinpath(output_path, \"synthetic_benchmark.csv\"), bmk())" + "text": "3.1 Benchmark\n\n# Benchmark generators:\ngenerators = Dict(\n :wachter => GenericGenerator(opt=opt, λ=l2_λ),\n :revise => REVISEGenerator(opt=opt, λ=l2_λ),\n :greedy => GreedyGenerator(),\n)\n\n# Untrained Models:\nmodels = Dict(Symbol(\"cov$(Int(100*cov))\") => ECCCo.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs)\n\n# Measures:\nmeasures = [\n CounterfactualExplanations.distance,\n ECCCo.distance_from_energy,\n ECCCo.distance_from_targets,\n CounterfactualExplanations.validity,\n]\n\n\n3.1.1 Single CE\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n3.1.2 Full Benchmark\n\nbmks = []\nfor (dataname, dataset) in datasets\n for λ in Λ, temp in temps\n _generators = deepcopy(generators)\n _generators[:cce] = ECCCoGenerator(temp=temp, λ=[l2_λ,λ], opt=opt)\n _generators[:energy] = ECCCo.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n _generators[:target] = ECCCo.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=_generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=5,\n initialization=:identity,\n )\n bmk.evaluation.λ .= λ\n bmk.evaluation.temperature .= temp\n push!(bmks, bmk)\n end\nend\nbmk = reduce(vcat, bmks)\n\n\nCSV.write(joinpath(output_path, \"synthetic_benchmark.csv\"), bmk())" }, { "objectID": "notebooks/references.html", diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml index d5d8d17a7a0a789c1182b3436ae25b9c0f895565..6dbe102ca9eb0869445e2279efae73b9f686cf9d 100644 --- a/notebooks/Manifest.toml +++ b/notebooks/Manifest.toml @@ -496,7 +496,7 @@ git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" version = "0.6.8" -[[deps.ECCCE]] +[[deps.ECCCo]] deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "JointEnergyModels", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "PkgTemplates", "Plots", "Random", "SliceMap", "Statistics", "StatsBase", "StatsPlots", "Term"] path = ".." uuid = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" diff --git a/notebooks/Project.toml b/notebooks/Project.toml index 5e61dae429cc465ff36306b53f150ec683261131..9d63873a5c251704aed1ab60c1256b1f9c4c8f71 100644 --- a/notebooks/Project.toml +++ b/notebooks/Project.toml @@ -1,6 +1,6 @@ [deps] AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67" -ECCCE = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" +ECCCo = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e" diff --git a/notebooks/intro.qmd b/notebooks/intro.qmd index 1c2ad2720767c6c5a589aa18f191edf1d4756330..b146559fde87570e3143cbc1bd9060faa50e4b03 100644 --- a/notebooks/intro.qmd +++ b/notebooks/intro.qmd @@ -122,14 +122,14 @@ We can generate samples from $p_{\theta}(X|y)$ following @grathwohl2020your. In #| label: fig-energy #| output: true -M = ECCCE.ConformalModel(conf_model, mach.fitresult) +M = ECCCo.ConformalModel(conf_model, mach.fitresult) niter = 100 nsamples = 100 plts = [] for (i,target) ∈ enumerate(counterfactual_data.y_levels) - sampler = ECCCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100) + sampler = ECCCo.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100) Xgen = rand(sampler, nsamples) plt = Plots.plot(M, counterfactual_data; target=target, zoom=-3,cbar=false) Plots.scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=i,shape=:star,label="X|y=$target") @@ -294,11 +294,11 @@ Plots.plot(plts..., layout=(length(cvgs),length(cvgs)), size=(2img_height*length niter = 100 nsamples = 100 -M = ECCCE.ConformalModel(conf_model, mach.fitresult; likelihood=:classification_multi) +M = ECCCo.ConformalModel(conf_model, mach.fitresult; likelihood=:classification_multi) plts = [] for target ∈ counterfactual_data.y_levels - sampler = ECCCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100) + sampler = ECCCo.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100) Xgen = rand(sampler, nsamples) plt = Plots.plot(M, counterfactual_data; target=target, zoom=-0.5,cbar=false) Plots.scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target") @@ -354,10 +354,10 @@ datasets = Dict( # Untrained Models: models = Dict( - :cov75 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.75)), - :cov80 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.80)), - :cov90 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.90)), - :cov99 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.99)), + :cov75 => ECCCo.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.75)), + :cov80 => ECCCo.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.80)), + :cov90 => ECCCo.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.90)), + :cov99 => ECCCo.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.99)), ) ``` @@ -369,8 +369,8 @@ using CounterfactualExplanations.Evaluation: benchmark bmks = [] measures = [ CounterfactualExplanations.distance, - ECCCE.distance_from_energy, - ECCCE.distance_from_targets + ECCCo.distance_from_energy, + ECCCo.distance_from_targets ] for (dataname, dataset) in datasets bmk = benchmark( diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index d26f3ca62ce20d7461625df854a428b0e12761c8..5e7f7aa615f8b9bf3f7c549c7c61f349961e3c7c 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -33,7 +33,7 @@ First, let's create a couple of image classifier architectures: epochs = 100 batch_size = minimum([Int(round(n_obs/10)), 128]) n_hidden = 32 -activation = Flux.relu +activation = Flux.swish # builder = MLJFlux.@builder Flux.Chain( # Dense(n_in, n_hidden, activation), # Dense(n_hidden, n_hidden, activation), @@ -43,7 +43,7 @@ activation = Flux.relu # # BatchNorm(n_hidden, activation), # Dense(n_hidden, n_out), # ) -builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.2, σ=activation) +builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation) # builder = MLJFlux.MLP( # hidden=( # n_hidden, @@ -52,7 +52,7 @@ builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.2, σ=activation) # ), # σ=activation # ) -α = [1.0,1.0,1e-2] +α = [1.0,1.0,1e-1] # Simple MLP: mlp = NeuralNetworkClassifier( @@ -93,13 +93,13 @@ cov = .95 conf_model = conformal_model(jem; method=:adaptive_inductive, coverage=cov) mach = machine(conf_model, X, labels) fit!(mach) -M = ECCCE.ConformalModel(mach.model, mach.fitresult) +M = ECCCo.ConformalModel(mach.model, mach.fitresult) ``` ```{julia} if mach.model.model isa JointEnergyModels.JointEnergyClassifier jem = mach.model.model.jem - n_iter = 100 + n_iter = 500 _w = 1500 plts = [] neach = 10 @@ -151,12 +151,12 @@ ce_jsma = generate_counterfactual( initialization=:identity, ) -# ECCCE: +# ECCCo: λ=[0.0,1.0] temp=0.01 -# Generate counterfactual using ECCCE generator: -generator = ECCCEGenerator( +# Generate counterfactual using ECCCo generator: +generator = CCEGenerator( λ=λ, temp=temp, opt=Flux.Optimise.Adam(), @@ -168,8 +168,8 @@ ce_conformal = generate_counterfactual( converge_when=:generator_conditions, ) -# Generate counterfactual using ECCCE generator: -generator = ECCCEGenerator( +# Generate counterfactual using ECCCo generator: +generator = CCEGenerator( λ=λ, temp=temp, opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0), @@ -191,7 +191,92 @@ p1 = Plots.plot( plts = [p1] ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma] -_names = ["Wachter", "ECCCE", "JSMA", "ECCCE-JSMA"] +_names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"] +for x in zip(ces, _names) + ce, _name = (x[1],x[2]) + x = CounterfactualExplanations.counterfactual(ce) + _phat = target_probs(ce) + _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" + plt = Plots.plot( + convert2image(MNIST, reshape(x,28,28)), + axis=nothing, + size=(img_height, img_height), + title=_title + ) + plts = [plts..., plt] +end +plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) +display(plt) +savefig(plt, joinpath(www_path, "cce_mnist.png")) +``` + +```{julia} +# Random.seed!(1234) + +# Set up search: +factual_label = 8 +x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) +target = 3 +factual = predict_label(M, counterfactual_data, x)[1] +γ = 0.5 +T = 100 + +# Generate counterfactual using generic generator: +generator = GenericGenerator(opt=Flux.Optimise.Adam(),) +ce_wachter = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, +) + +generator = GreedyGenerator(η=1.0) +ce_jsma = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, +) + +# ECCCo: +λ=[0.0,1.0,1.0] +temp=0.01 + +# Generate counterfactual using ECCCo generator: +generator = ECCCoGenerator( + λ=λ, + temp=temp, + opt=Flux.Optimise.Adam(), +) +ce_conformal = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, + converge_when=:generator_conditions, +) + +# Generate counterfactual using ECCCo generator: +generator = ECCCoGenerator( + λ=λ, + temp=temp, + opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0), +) +ce_conformal_jsma = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, + converge_when=:generator_conditions, +) + +# Plot: +p1 = Plots.plot( + convert2image(MNIST, reshape(x,28,28)), + axis=nothing, + size=(img_height, img_height), + title="Factual" +) +plts = [p1] + +ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma] +_names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"] for x in zip(ces, _names) ce, _name = (x[1],x[2]) x = CounterfactualExplanations.counterfactual(ce) @@ -226,8 +311,8 @@ generators = Dict( # Measures: measures = [ CounterfactualExplanations.distance, - ECCCE.distance_from_energy, - ECCCE.distance_from_targets, + ECCCo.distance_from_energy, + ECCCo.distance_from_targets, CounterfactualExplanations.validity, ] ``` \ No newline at end of file diff --git a/notebooks/proposal.qmd b/notebooks/proposal.qmd index bdbd45b34b2cef38cf7889a408930e95cab37ca8..2511087974deb9c0660b8c83da75f998c09f8506 100644 --- a/notebooks/proposal.qmd +++ b/notebooks/proposal.qmd @@ -188,7 +188,7 @@ If the computed value is different from zero, we can reject the null-hypothesis ## Conformal Counterfactual Explanations -In @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models. +In @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCo), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models. ### Minimizing Predictive Uncertainty diff --git a/notebooks/setup.jl b/notebooks/setup.jl index 5997cae6d12ea9c680bea6c7e9475c6fda393a2a..a377a5880eef6b65f10082c6a59916a23a371f88 100644 --- a/notebooks/setup.jl +++ b/notebooks/setup.jl @@ -6,8 +6,8 @@ setup_notebooks = quote using AlgebraOfGraphics using AlgebraOfGraphics: Violin, BoxPlot, BarPlot using CairoMakie - using ECCCE - using ECCCE: set_size_penalty, distance_from_energy, distance_from_targets + using ECCCo + using ECCCo: set_size_penalty, distance_from_energy, distance_from_targets using Chain: @chain using ConformalPrediction using CounterfactualExplanations diff --git a/notebooks/synthetic.qmd b/notebooks/synthetic.qmd index 0ac30edb952175fd984e4cbdf1c68fcdcc2168d5..620d9cc945dfd2c0f244a4f9539ddfbcf3fce167 100644 --- a/notebooks/synthetic.qmd +++ b/notebooks/synthetic.qmd @@ -60,16 +60,16 @@ for (dataname, data) in datasets conf_model = conformal_model(clf; method=:simple_inductive, coverage=cov) mach = machine(conf_model, X, y) fit!(mach) - M = ECCCE.ConformalModel(mach.model, mach.fitresult) + M = ECCCo.ConformalModel(mach.model, mach.fitresult) - # Set up ECCCE: + # Set up ECCCo: factual_label = predict_label(M, data, x)[1] target_label = data.y_levels[data.y_levels .!= factual_label][1] for λ in Λ, temp in temps - # ECCCE for given classifier, coverage, temperature and λ: - generator = ECCCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt) + # ECCCo for given classifier, coverage, temperature and λ: + generator = ECCCoGenerator(temp=temp, λ=[l2_λ,λ], opt=opt) @assert predict_label(M, data, x) != target_label ce = try generate_counterfactual( @@ -158,13 +158,13 @@ generators = Dict( ) # Untrained Models: -models = Dict(Symbol("cov$(Int(100*cov))") => ECCCE.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs) +models = Dict(Symbol("cov$(Int(100*cov))") => ECCCo.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs) # Measures: measures = [ CounterfactualExplanations.distance, - ECCCE.distance_from_energy, - ECCCE.distance_from_targets, + ECCCo.distance_from_energy, + ECCCo.distance_from_targets, CounterfactualExplanations.validity, ] ``` @@ -187,7 +187,7 @@ for (dataname, data) in datasets # Model training: M = train(M, data) - # Set up ECCCE: + # Set up ECCCo: factual_label = predict_label(M, data, x)[1] target_label = data.y_levels[data.y_levels .!= factual_label][1] @@ -195,13 +195,13 @@ for (dataname, data) in datasets # Generators: _generators = deepcopy(generators) - _generators[:cce] = ECCCEGenerator(temp=_temp, λ=[l2_λ,λ], opt=opt) - _generators[:energy] = ECCCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt) - _generators[:target] = ECCCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt) + _generators[:cce] = ECCCoGenerator(temp=_temp, λ=[l2_λ,λ], opt=opt) + _generators[:energy] = ECCCo.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt) + _generators[:target] = ECCCo.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt) for (gen_name, gen) in _generators - # ECCCE for given models, λ and generator: + # ECCCo for given models, λ and generator: @assert predict_label(M, data, x) != target_label ce = try generate_counterfactual( @@ -300,9 +300,9 @@ bmks = [] for (dataname, dataset) in datasets for λ in Λ, temp in temps _generators = deepcopy(generators) - _generators[:cce] = ECCCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt) - _generators[:energy] = ECCCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt) - _generators[:target] = ECCCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt) + _generators[:cce] = ECCCoGenerator(temp=temp, λ=[l2_λ,λ], opt=opt) + _generators[:energy] = ECCCo.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt) + _generators[:target] = ECCCo.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt) bmk = benchmark( dataset; models=deepcopy(models), diff --git a/paper/paper.pdf b/paper/paper.pdf index 6a82e2a3a87f7b983aca74527c0f3c31e0276166..2ba71c2968b35cca31816ccb45be28b071acf7fa 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/paper/paper.tex b/paper/paper.tex index 13799de77ebc15c684582cf5e28ce0d018267804..bd18cf2a571a04e33ac4f7085bbdb62132d7190c 100644 --- a/paper/paper.tex +++ b/paper/paper.tex @@ -45,7 +45,7 @@ \newtheorem{definition}{Definition}[section] -\title{Plausibility isn't all you need: Conformal Counterfactual Explanations} +\title{ECCCos from the Black Box: Letting Models speak for Themselves} % The \author macro works with any number of authors. There are two commands @@ -266,7 +266,7 @@ The fact that conformal classifiers produce set-valued predictions introduces a where $\kappa \in \{0,1\}$ is a hyper-parameter and $C_{\theta,\mathbf{y}}(\mathbf{x}_i;\alpha)$ can be interpreted as the probability of label $\mathbf{y}$ being included in the prediction set. Formally, it is defined as $C_{\theta,\mathbf{y}}(\mathbf{x}_i;\alpha):=\sigma\left((s(\mathbf{x}_i,\mathbf{y})-\alpha) T^{-1}\right)$ for $\mathbf{y}\in\mathcal{Y}$ where $\sigma$ is the sigmoid function and $T$ is a hyper-parameter used for temperature scaling \citep{stutz2022learning}. -Penalizing the set size in this way is in principal enough to train efficient conformal classifiers \citep{stutz2022learning}. As we explained above, the set size is also closely linked to predictive uncertainty at the local level. This makes the smooth penalty defined in Equation~\ref{eq:setsize} useful in the context of meeting our objective of generating plausible counterfactuals. In particular, we adapt Equation~\ref{eq:general} to define the baseline objective for Conformal Counterfactual Explanations (ECCCE): +Penalizing the set size in this way is in principal enough to train efficient conformal classifiers \citep{stutz2022learning}. As we explained above, the set size is also closely linked to predictive uncertainty at the local level. This makes the smooth penalty defined in Equation~\ref{eq:setsize} useful in the context of meeting our objective of generating plausible counterfactuals. In particular, we adapt Equation~\ref{eq:general} to define the baseline objective for Conformal Counterfactual Explanations (ECCCo): \begin{equation}\label{eq:cce} \begin{aligned} @@ -276,7 +276,7 @@ Penalizing the set size in this way is in principal enough to train efficient co Since we can still retrieve unperturbed softmax outputs from our conformal classifier $M_{\theta}$, we are free to work with any loss function of our choice. For example, we could use standard cross-entropy for $\text{yloss}$. -In order to generate prediction sets $C_{\theta}(f(\mathbf{Z}^\prime);\alpha)$ for any Black Box Model we merely need to perform a single calibration pass through a holdout set $\mathcal{D}_{\text{cal}}$. Arguably, data is typically abundant and in most applications practitioners tend to hold out a test data set anyway. Our proposed approach for ECCCE therefore removes the restriction on the family of predictive models, at the small cost of reserving a subset of the available data for calibration. +In order to generate prediction sets $C_{\theta}(f(\mathbf{Z}^\prime);\alpha)$ for any Black Box Model we merely need to perform a single calibration pass through a holdout set $\mathcal{D}_{\text{cal}}$. Arguably, data is typically abundant and in most applications practitioners tend to hold out a test data set anyway. Our proposed approach for ECCCo therefore removes the restriction on the family of predictive models, at the small cost of reserving a subset of the available data for calibration. \section{Experiments} diff --git a/src/ECCCE.jl b/src/ECCCo.jl similarity index 71% rename from src/ECCCE.jl rename to src/ECCCo.jl index 8a031fffa21f2d3f65f938b79725f348dcfc8703..82c901d72e83b0b8ef6abf2ca43d5801c1403489 100644 --- a/src/ECCCE.jl +++ b/src/ECCCo.jl @@ -1,4 +1,4 @@ -module ECCCE +module ECCCo using CounterfactualExplanations import MLJModelInterface as MMI @@ -9,6 +9,6 @@ include("losses.jl") include("generator.jl") include("sampling.jl") -export ECCCEGenerator, EnergySampler, set_size_penalty, distance_from_energy +export ECCCoGenerator, EnergySampler, set_size_penalty, distance_from_energy end \ No newline at end of file diff --git a/src/generator.jl b/src/generator.jl index ac598d483e8a6f246210b4fa47daf436ba508b14..2dc3e460c8bfed0846ecd56f90f6792fd99f3280 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -1,17 +1,17 @@ using CounterfactualExplanations.Objectives -"Constructor for `ECCCEGenerator`." -function ECCCEGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], κ::Real=1.0, temp::Real=0.05, kwargs...) +"Constructor for `ECCCoGenerator`." +function ECCCoGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], κ::Real=1.0, temp::Real=0.05, kwargs...) function _set_size_penalty(ce::AbstractCounterfactualExplanation) - return ECCCE.set_size_penalty(ce; κ=κ, temp=temp) + return ECCCo.set_size_penalty(ce; κ=κ, temp=temp) end _penalties = [Objectives.distance_l2, _set_size_penalty] λ = λ isa AbstractFloat ? [0.0, λ] : λ return Generator(; penalty=_penalties, λ=λ, kwargs...) end -"Constructor for `ECECCCEGenerator`: Energy Constrained Conformal Counterfactual Explanation Generator." -function ECECCCEGenerator(; +"Constructor for `ECECCCoGenerator`: Energy Constrained Conformal Counterfactual Explanation Generator." +function ECECCCoGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0, 1.0], κ::Real=1.0, temp::Real=0.5, @@ -21,23 +21,23 @@ function ECECCCEGenerator(; kwargs... ) function _set_size_penalty(ce::AbstractCounterfactualExplanation) - return ECCCE.set_size_penalty(ce; κ=κ, temp=temp) + return ECCCo.set_size_penalty(ce; κ=κ, temp=temp) end - _penalties = [Objectives.distance_l2, _set_size_penalty, ECCCE.distance_from_energy] + _penalties = [Objectives.distance_l2, _set_size_penalty, ECCCo.distance_from_energy] λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ return Generator(; penalty=_penalties, λ=λ, opt=opt, kwargs...) end "Constructor for `EnergyDrivenGenerator`." function EnergyDrivenGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], kwargs...) - _penalties = [Objectives.distance_l2, ECCCE.distance_from_energy] + _penalties = [Objectives.distance_l2, ECCCo.distance_from_energy] λ = λ isa AbstractFloat ? [0.0, λ] : λ return Generator(; penalty=_penalties, λ=λ, kwargs...) end "Constructor for `TargetDrivenGenerator`." function TargetDrivenGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], kwargs...) - _penalties = [Objectives.distance_l2, ECCCE.distance_from_targets] + _penalties = [Objectives.distance_l2, ECCCo.distance_from_targets] λ = λ isa AbstractFloat ? [0.0, λ] : λ return Generator(; penalty=_penalties, λ=λ, kwargs...) end \ No newline at end of file diff --git a/src/penalties.jl b/src/penalties.jl index 37474c9e43555dc497bdb9f766ec2915d45d5e63..81b953fe5884a9fec638a2db01df73edcdd3176a 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -42,7 +42,7 @@ function distance_from_energy( ignore_derivatives() do _dict = ce.params if !(:energy_sampler ∈ collect(keys(_dict))) - _dict[:energy_sampler] = ECCCE.EnergySampler(ce; kwargs...) + _dict[:energy_sampler] = ECCCo.EnergySampler(ce; kwargs...) end sampler = _dict[:energy_sampler] push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer)) diff --git a/test/runtests.jl b/test/runtests.jl index 569772e439080dfee852354167d0926c7c60170f..2967ba336a195b729d1d5cbdac4e49de06361719 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ -using ECCCE +using ECCCo using Test -@testset "ECCCE.jl" begin +@testset "ECCCo.jl" begin # Write your tests here. end diff --git a/www/cce_mnist.png b/www/cce_mnist.png index 3db6423c613ebc10174d84b0ae8f6f3cf8601ff0..55ff8d5650856abe43798dcd34a10da9de1c34f0 100644 Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ