diff --git a/CITATION.bib b/CITATION.bib
index d673e1fe2e7d6a364ad2ad098f83f4a366cbd319..752944a71c777c0c19921312b6d83677ded8046e 100644
--- a/CITATION.bib
+++ b/CITATION.bib
@@ -1,7 +1,7 @@
-@misc{CCE.jl,
+@misc{ECCCE.jl,
   author = {Patrick Altmeyer},
-  title = {CCE.jl},
-  url = {https://github.com/pat-alt/CCE.jl},
+  title = {ECCCE.jl},
+  url = {https://github.com/pat-alt/ECCCE.jl},
   version = {v0.1.0},
   year = {2023},
   month = {2}
diff --git a/Project.toml b/Project.toml
index 8783ebf8911f11f54210c4272232c91c0679b7aa..647598d4a3b7da049887eb85135a2e877ac960dc 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,4 +1,4 @@
-name = "CCE"
+name = "ECCCE"
 uuid = "0232c203-4013-4b0d-ad96-43e3e11ac3bf"
 authors = ["Patrick Altmeyer"]
 version = "0.1.0"
diff --git a/README.md b/README.md
index a0fbd4f4ce07925cd900b33ab5f4293570e252df..c599843e1c23d5e35d082b8b9dacad5e971aa00e 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,3 @@
-# CCE
+# ECCCE
 
-[](https://github.com/pat-alt/CCE.jl/actions/workflows/CI.yml?query=branch%3Amain)
+[](https://github.com/pat-alt/ECCCE.jl/actions/workflows/CI.yml?query=branch%3Amain)
diff --git a/_freeze/dev/proposal/execute-results/html.json b/_freeze/dev/proposal/execute-results/html.json
index 1d98ce871632c4fdbe1f483ab37133fa1a8ee291..195e16d6cc088edd81dae8c7a7f5ee331e48f895 100644
--- a/_freeze/dev/proposal/execute-results/html.json
+++ b/_freeze/dev/proposal/execute-results/html.json
@@ -1,7 +1,7 @@
 {
   "hash": "d7b4f9bf7f4bff7ce610fc8be4dcfb8b",
   "result": {
-    "markdown": "---\ntitle: High-Fidelity Counterfactual Explanations through Conformal Prediction\nsubtitle: Research Proposal\nabstract: |\n We propose Conformal Counterfactual Explanations: an effortless and rigorous way to produce realistic and faithful Counterfactual Explanations using Conformal Prediction. To address the need for realistic counterfactuals, existing work has primarily relied on separate generative models to learn the data-generating process. 
While this is an effective way to produce plausible and model-agnostic counterfactual explanations, it not only introduces a significant engineering overhead but also reallocates the task of creating realistic model explanations from the model itself to the generative model. Recent work has shown that there is no need for any of this when working with probabilistic models that explicitly quantify their own uncertainty. Unfortunately, most models used in practice still do not fulfil that basic requirement, in which case we would like to have a way to quantify predictive uncertainty in a post-hoc fashion.\n---\n\n\n\n## Motivation\n\nCounterfactual Explanations are a powerful, flexible and intuitive way to not only explain black-box models but also enable affected individuals to challenge them through the means of Algorithmic Recourse. \n\n### Counterfactual Explanations or Adversarial Examples?\n\nMost state-of-the-art approaches to generating Counterfactual Explanations (CE) rely on gradient descent in the feature space. The key idea is to perturb inputs $x\\in\\mathcal{X}$ into a black-box model $f: \\mathcal{X} \\mapsto \\mathcal{Y}$ in order to change the model output $f(x)$ to some pre-specified target value $t\\in\\mathcal{Y}$. Formally, this boils down to defining some loss function $\\ell(f(x),t)$ and taking gradient steps in the minimizing direction. The so-generated counterfactuals are considered valid as soon as the predicted label matches the target label. A stripped-down counterfactual explanation is therefore little different from an adversarial example. In @fig-adv, for example, generic counterfactual search as in @wachter2017counterfactual has been applied to MNIST data.\n\n\n\n\n\n{#fig-adv}\n\nThe crucial difference between adversarial examples and counterfactuals is one of intent. 
While adversarial examples are typically intended to go unnoticed, counterfactuals in the context of Explainable AI are generally sought to be \"plausible\", \"realistic\" or \"feasible\". To fulfil this latter goal, researchers have come up with a myriad of ways. @joshi2019realistic were among the first to suggest that instead of searching counterfactuals in the feature space, we can instead traverse a latent embedding learned by a surrogate generative model. Similarly, @poyiadzi2020face use density ... Finally, @karimi2021algorithmic argues that counterfactuals should comply with the causal model that generates them [CHECK IF WE CAN PHASE THIS LIKE THIS]. Other related approaches include ... All of these different approaches have a common goal: they aim to ensure that the generated counterfactuals comply with the (learned) data-generating process (DGB). \n\n::: {#def-plausible}\n\n## Plausible Counterfactuals\n\nFormally, if $x \\sim \\mathcal{X}$ and for the corresponding counterfactual we have $x^{\\prime}\\sim\\mathcal{X}^{\\prime}$, then for $x^{\\prime}$ to be considered a plausible counterfactual, we need: $\\mathcal{X} \\approxeq \\mathcal{X}^{\\prime}$.\n\n:::\n\nIn the context of Algorithmic Recourse, it makes sense to strive for plausible counterfactuals, since anything else would essentially require individuals to move to out-of-distribution states. But it is worth noting that our ambition to meet this goal, may have implications on our ability to faithfully explain the behaviour of the underlying black-box model (arguably our principal goal). By essentially decoupling the task of learning plausible representations of the data from the model itself, we open ourselves up to vulnerabilities. Using a separate generative model to learn $\\mathcal{X}$, for example, has very serious implications for the generated counterfactuals. 
@fig-latent compares the results of applying REVISE [@joshi2019realistic] to MNIST data using two different Variational Auto-Encoders: while the counterfactual generated using an expressive (strong) VAE is compelling, the result relying on a less expressive (weak) VAE is not even valid. In this latter case, the decoder step of the VAE fails to yield values in $\\mathcal{X}$ and hence the counterfactual search in the learned latent space is doomed. \n\n{#fig-latent}\n\n> Here it would be nice to have another example where we poison the data going into the generative model to hide biases present in the data (e.g. Boston housing).\n\n- Latent can be manipulated: \n - train biased model\n - train VAE with biased variable removed/attacked (use Boston housing dataset)\n - hypothesis: will generate bias-free explanations\n\n### From Plausible to High-Fidelity Counterfactuals {#sec-fidelity}\n\nIn light of the findings, we propose to generally avoid using surrogate models to learn $\\mathcal{X}$ in the context of Counterfactual Explanations.\n\n::: {#prp-surrogate}\n\n## Avoid Surrogates\n\nSince we are in the business of explaining a black-box model, the task of learning realistic representations of the data should not be reallocated from the model itself to some surrogate model.\n\n:::\n\nIn cases where the use of surrogate models cannot be avoided, we propose to weigh the plausibility of counterfactuals against their fidelity to the black-box model. In the context of Explainable AI, fidelity is defined as describing how an explanation approximates the prediction of the black-box model [@molnar2020interpretable]. Fidelity has become the default metric for evaluating Local Model-Agnostic Models, since they often involve local surrogate models whose predictions need not always match those of the black-box model. \n\nIn the case of Counterfactual Explanations, the concept of fidelity has so far been ignored. 
This is not altogether surprising, since by construction and design, Counterfactual Explanations work with the predictions of the black-box model directly: as stated above, a counterfactual $x^{\\prime}$ is considered valid if and only if $f(x^{\\prime})=t$, where $t$ denote some target outcome. \n\nDoes fidelity even make sense in the context of CE, and if so, how can we define it? In light of the examples in the previous section, we think it is urgent to introduce a notion of fidelity in this context, that relates to the distributional properties of the generated counterfactuals. In particular, we propose that a high-fidelity counterfactual $x^{\\prime}$ complies with the class-conditional distribution $\\mathcal{X}_{\\theta} = p_{\\theta}(X|y)$ where $\\theta$ denote the black-box model parameters. \n\n::: {#def-fidele}\n\n## High-Fidelity Counterfactuals\n\nLet $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\\theta$. Then for $x^{\\prime}$ to be considered a high-fidelity counterfactual, we need: $\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}$ where $t$ denotes the target outcome.\n\n:::\n\nIn order to assess the fidelity of counterfactuals, we propose the following two-step procedure:\n\n1) Generate samples $X_{\\theta}|y$ and $X^{\\prime}$ from $\\mathcal{X}_{\\theta}|t$ and $\\mathcal{X}^{\\prime}$, respectively.\n2) Compute the Maximum Mean Discrepancy (MMD) between $X_{\\theta}|y$ and $X^{\\prime}$. \n\nIf the computed value is different from zero, we can reject the null-hypothesis of fidelity.\n\n> Two challenges here: 1) implementing the sampling procedure in @grathwohl2020your; 2) it is unclear if MMD is really the right way to measure this. \n\n## Conformal Counterfactual Explanations\n\nIn @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. 
In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (CCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models. \n\n### Minimizing Predictive Uncertainty\n\n@schut2021generating demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. Similarly, @antoran2020getting ...\n\n- Problem: restricted to Bayesian models.\n- Solution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction. \n\n### Background on Conformal Prediction\n\n- Distribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.\n- Conformal prediction is instance-based. So is CE. \n- Take any fitted model and turn it into a conformal model using calibration data.\n- Our approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway. \n\n> Does the coverage guarantee carry over to counterfactuals?\n\n### Generating Conformal Counterfactuals\n\nWhile Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, @stutz2022learning introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search ...\n\n> Challenge: still need to implement these loss functions. 
\n\n## Experiments\n\n### Research Questions\n\n- Is CP alone enough to ensure realistic counterfactuals?\n- Do counterfactuals improve further as the models get better?\n- Do counterfactuals get more realistic as coverage\n- What happens as we vary coverage and setsize?\n- What happens as we improve the model robustness?\n- What happens as we improve the model's ability to incorporate predictive uncertainty (deep ensemble, laplace)?\n- What happens if we combine with DiCE, ClaPROAR, Gravitational?\n- What about CE robustness to endogenous shifts [@altmeyer2023endogenous]?\n\n- Benchmarking:\n - add PROBE [@pawelczyk2022probabilistically] into the mix.\n - compare travel costs to domain shits.\n\n> Nice to have: What about using Laplace Approximation, then Conformal Prediction? What about using Conformalised Laplace? \n\n## References\n\n",
+    "markdown": "---\ntitle: High-Fidelity Counterfactual Explanations through Conformal Prediction\nsubtitle: Research Proposal\nabstract: |\n We propose Conformal Counterfactual Explanations: an effortless and rigorous way to produce realistic and faithful Counterfactual Explanations using Conformal Prediction. To address the need for realistic counterfactuals, existing work has primarily relied on separate generative models to learn the data-generating process. While this is an effective way to produce plausible and model-agnostic counterfactual explanations, it not only introduces a significant engineering overhead but also reallocates the task of creating realistic model explanations from the model itself to the generative model. Recent work has shown that there is no need for any of this when working with probabilistic models that explicitly quantify their own uncertainty. 
Unfortunately, most models used in practice still do not fulfil that basic requirement, in which case we would like to have a way to quantify predictive uncertainty in a post-hoc fashion.\n---\n\n\n\n## Motivation\n\nCounterfactual Explanations are a powerful, flexible and intuitive way to not only explain black-box models but also enable affected individuals to challenge them through the means of Algorithmic Recourse. \n\n### Counterfactual Explanations or Adversarial Examples?\n\nMost state-of-the-art approaches to generating Counterfactual Explanations (CE) rely on gradient descent in the feature space. The key idea is to perturb inputs $x\\in\\mathcal{X}$ into a black-box model $f: \\mathcal{X} \\mapsto \\mathcal{Y}$ in order to change the model output $f(x)$ to some pre-specified target value $t\\in\\mathcal{Y}$. Formally, this boils down to defining some loss function $\\ell(f(x),t)$ and taking gradient steps in the minimizing direction. The so-generated counterfactuals are considered valid as soon as the predicted label matches the target label. A stripped-down counterfactual explanation is therefore little different from an adversarial example. In @fig-adv, for example, generic counterfactual search as in @wachter2017counterfactual has been applied to MNIST data.\n\n\n\n\n\n{#fig-adv}\n\nThe crucial difference between adversarial examples and counterfactuals is one of intent. While adversarial examples are typically intended to go unnoticed, counterfactuals in the context of Explainable AI are generally sought to be \"plausible\", \"realistic\" or \"feasible\". To fulfil this latter goal, researchers have come up with a myriad of ways. @joshi2019realistic were among the first to suggest that instead of searching counterfactuals in the feature space, we can instead traverse a latent embedding learned by a surrogate generative model. Similarly, @poyiadzi2020face use density ... 
Finally, @karimi2021algorithmic argues that counterfactuals should comply with the causal model that generates them [CHECK IF WE CAN PHASE THIS LIKE THIS]. Other related approaches include ... All of these different approaches have a common goal: they aim to ensure that the generated counterfactuals comply with the (learned) data-generating process (DGB). \n\n::: {#def-plausible}\n\n## Plausible Counterfactuals\n\nFormally, if $x \\sim \\mathcal{X}$ and for the corresponding counterfactual we have $x^{\\prime}\\sim\\mathcal{X}^{\\prime}$, then for $x^{\\prime}$ to be considered a plausible counterfactual, we need: $\\mathcal{X} \\approxeq \\mathcal{X}^{\\prime}$.\n\n:::\n\nIn the context of Algorithmic Recourse, it makes sense to strive for plausible counterfactuals, since anything else would essentially require individuals to move to out-of-distribution states. But it is worth noting that our ambition to meet this goal, may have implications on our ability to faithfully explain the behaviour of the underlying black-box model (arguably our principal goal). By essentially decoupling the task of learning plausible representations of the data from the model itself, we open ourselves up to vulnerabilities. Using a separate generative model to learn $\\mathcal{X}$, for example, has very serious implications for the generated counterfactuals. @fig-latent compares the results of applying REVISE [@joshi2019realistic] to MNIST data using two different Variational Auto-Encoders: while the counterfactual generated using an expressive (strong) VAE is compelling, the result relying on a less expressive (weak) VAE is not even valid. In this latter case, the decoder step of the VAE fails to yield values in $\\mathcal{X}$ and hence the counterfactual search in the learned latent space is doomed. \n\n{#fig-latent}\n\n> Here it would be nice to have another example where we poison the data going into the generative model to hide biases present in the data (e.g. 
Boston housing).\n\n- Latent can be manipulated: \n - train biased model\n - train VAE with biased variable removed/attacked (use Boston housing dataset)\n - hypothesis: will generate bias-free explanations\n\n### From Plausible to High-Fidelity Counterfactuals {#sec-fidelity}\n\nIn light of the findings, we propose to generally avoid using surrogate models to learn $\\mathcal{X}$ in the context of Counterfactual Explanations.\n\n::: {#prp-surrogate}\n\n## Avoid Surrogates\n\nSince we are in the business of explaining a black-box model, the task of learning realistic representations of the data should not be reallocated from the model itself to some surrogate model.\n\n:::\n\nIn cases where the use of surrogate models cannot be avoided, we propose to weigh the plausibility of counterfactuals against their fidelity to the black-box model. In the context of Explainable AI, fidelity is defined as describing how an explanation approximates the prediction of the black-box model [@molnar2020interpretable]. Fidelity has become the default metric for evaluating Local Model-Agnostic Models, since they often involve local surrogate models whose predictions need not always match those of the black-box model. \n\nIn the case of Counterfactual Explanations, the concept of fidelity has so far been ignored. This is not altogether surprising, since by construction and design, Counterfactual Explanations work with the predictions of the black-box model directly: as stated above, a counterfactual $x^{\\prime}$ is considered valid if and only if $f(x^{\\prime})=t$, where $t$ denote some target outcome. \n\nDoes fidelity even make sense in the context of CE, and if so, how can we define it? In light of the examples in the previous section, we think it is urgent to introduce a notion of fidelity in this context, that relates to the distributional properties of the generated counterfactuals. 
In particular, we propose that a high-fidelity counterfactual $x^{\\prime}$ complies with the class-conditional distribution $\\mathcal{X}_{\\theta} = p_{\\theta}(X|y)$ where $\\theta$ denote the black-box model parameters. \n\n::: {#def-fidele}\n\n## High-Fidelity Counterfactuals\n\nLet $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\\theta$. Then for $x^{\\prime}$ to be considered a high-fidelity counterfactual, we need: $\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}$ where $t$ denotes the target outcome.\n\n:::\n\nIn order to assess the fidelity of counterfactuals, we propose the following two-step procedure:\n\n1) Generate samples $X_{\\theta}|y$ and $X^{\\prime}$ from $\\mathcal{X}_{\\theta}|t$ and $\\mathcal{X}^{\\prime}$, respectively.\n2) Compute the Maximum Mean Discrepancy (MMD) between $X_{\\theta}|y$ and $X^{\\prime}$. \n\nIf the computed value is different from zero, we can reject the null-hypothesis of fidelity.\n\n> Two challenges here: 1) implementing the sampling procedure in @grathwohl2020your; 2) it is unclear if MMD is really the right way to measure this. \n\n## Conformal Counterfactual Explanations\n\nIn @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models. \n\n### Minimizing Predictive Uncertainty\n\n@schut2021generating demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. 
Similarly, @antoran2020getting ...\n\n- Problem: restricted to Bayesian models.\n- Solution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction. \n\n### Background on Conformal Prediction\n\n- Distribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.\n- Conformal prediction is instance-based. So is CE. \n- Take any fitted model and turn it into a conformal model using calibration data.\n- Our approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway. \n\n> Does the coverage guarantee carry over to counterfactuals?\n\n### Generating Conformal Counterfactuals\n\nWhile Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, @stutz2022learning introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search ...\n\n> Challenge: still need to implement these loss functions. 
\n\n## Experiments\n\n### Research Questions\n\n- Is CP alone enough to ensure realistic counterfactuals?\n- Do counterfactuals improve further as the models get better?\n- Do counterfactuals get more realistic as coverage\n- What happens as we vary coverage and setsize?\n- What happens as we improve the model robustness?\n- What happens as we improve the model's ability to incorporate predictive uncertainty (deep ensemble, laplace)?\n- What happens if we combine with DiCE, ClaPROAR, Gravitational?\n- What about CE robustness to endogenous shifts [@altmeyer2023endogenous]?\n\n- Benchmarking:\n - add PROBE [@pawelczyk2022probabilistically] into the mix.\n - compare travel costs to domain shits.\n\n> Nice to have: What about using Laplace Approximation, then Conformal Prediction? What about using Conformalised Laplace? \n\n## References\n\n",
     "supporting": [
       "proposal_files/figure-html"
     ],
diff --git a/_freeze/notebooks/intro/execute-results/html.json b/_freeze/notebooks/intro/execute-results/html.json
index d600296a7a0fc638125faa2fcfcd7a454f3a84cf..28ce12ee42a881b7ab0755962121c93e8e6fb4bc 100644
--- a/_freeze/notebooks/intro/execute-results/html.json
+++ b/_freeze/notebooks/intro/execute-results/html.json
@@ -1,7 +1,7 @@
 {
   "hash": "43d5045964ca39def434cb65914681bc",
   "result": {
-    "markdown": "::: {.cell execution_count=1}\n``` {.julia .cell-code}\ninclude(\"notebooks/setup.jl\")\neval(setup_notebooks)\n```\n:::\n\n\n# `ConformalGenerator`\n\nIn this section, we will look at a simple example involving synthetic data, a black-box model and a generic Conformal Counterfactual Generator.\n\n## Black-box Model\n\nWe consider a simple binary classification problem. Let $(X_i, Y_i), \\ i=1,...,n$ denote our feature-label pairs and let $\\mu: \\mathcal{X} \\mapsto \\mathcal{Y}$ denote the mapping from features to labels. For illustration purposes, we will use linearly separable data. 
\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\ncounterfactual_data = load_linearly_separable()\n```\n:::\n\n\nWhile we could use a linear classifier in this case, let's pretend we need a black-box model for this task and rely on a small Multi-Layer Perceptron (MLP):\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nbuilder = MLJFlux.@builder Flux.Chain(\n Dense(n_in, 32, relu),\n Dense(32, n_out)\n)\nclf = NeuralNetworkClassifier(builder=builder, epochs=100)\n```\n:::\n\n\nWe can fit this model to data to produce plug-in predictions. \n\n## Conformal Prediction\n\nHere we will instead use a specific case of CP called *split conformal prediction* which can then be summarized as follows:^[In other places split conformal prediction is sometimes referred to as *inductive* conformal prediction.]\n\n1. Partition the training into a proper training set and a separate calibration set: $\\mathcal{D}_n=\\mathcal{D}^{\\text{train}} \\cup \\mathcal{D}^{\\text{cali}}$.\n2. Train the machine learning model on the proper training set: $\\hat\\mu_{i \\in \\mathcal{D}^{\\text{train}}}(X_i,Y_i)$.\n\nThe model $\\hat\\mu_{i \\in \\mathcal{D}^{\\text{train}}}$ can now produce plug-in predictions. \n\n::: callout-note\n\n## Starting Point\n\nNote that this represents the starting point in applications of Algorithmic Recourse: we have some pre-trained classifier $M$ for which we would like to generate plausible Counterfactual Explanations. Next, we turn to the calibration step. \n:::\n\n3. Compute nonconformity scores, $\\mathcal{S}$, using the calibration data $\\mathcal{D}^{\\text{cali}}$ and the fitted model $\\hat\\mu_{i \\in \\mathcal{D}^{\\text{train}}}$. \n4. For a user-specified desired coverage ratio $(1-\\alpha)$ compute the corresponding quantile, $\\hat{q}$, of the empirical distribution of nonconformity scores, $\\mathcal{S}$.\n5. 
For the given quantile and test sample $X_{\\text{test}}$, form the corresponding conformal prediction set: \n\n$$\nC(X_{\\text{test}})=\\{y:s(X_{\\text{test}},y) \\le \\hat{q}\\}\n$$ {#eq-set}\n\nThis is the default procedure used for classification and regression in [`ConformalPrediction.jl`](https://github.com/pat-alt/ConformalPrediction.jl). \n\nUsing the package, we can apply Split Conformal Prediction as follows:\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nX = table(permutedims(counterfactual_data.X))\ny = counterfactual_data.output_encoder.labels\nconf_model = conformal_model(clf; method=:simple_inductive)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\nTo be clear, all of the calibration steps (3 to 5) are post hoc, and yet none of them involved any changes to the model parameters. These are two important characteristics of Split Conformal Prediction (SCP) that make it particularly useful in the context of Algorithmic Recourse. Firstly, the fact that SCP involves posthoc calibration steps that happen after training, ensures that we need not place any restrictions on the black-box model itself. This stands in contrast to the approach proposed by @schut2021generating in which they essentially restrict the class of models to Bayesian models. Secondly, the fact that the model itself is kept entirely intact ensures that the generated counterfactuals maintain fidelity to the model. Finally, note that we also have not resorted to a surrogate model to learn more about $X \\sim \\mathcal{X}$. Instead, we have used the fitted model itself and a calibration data set to learn about the model's predictive uncertainty. \n\n## Differentiable CP\n\nIn order to use CP in the context of gradient-based counterfactual search, we need it to be differentiable. @stutz2022learning introduce a framework for training differentiable conformal predictors. 
They introduce a configurable loss function as well as smooth set size penalty.\n\n### Smooth Set Size Penalty\n\nStarting with the former, @stutz2022learning propose the following:\n\n$$\n\\Omega(C_{\\theta}(x;\\tau)) = = \\max (0, \\sum_k C_{\\theta,k}(x;\\tau) - \\kappa)\n$$ {#eq-size-loss}\n\nHere, $C_{\\theta,k}(x;\\tau)$ is loosely defined as the probability that class $k$ is assigned to the conformal prediction set $C$. In the context of Conformal Training, this penalty reduces the **inefficiency** of the conformal predictor. \n\nIn our context, we are not interested in improving the model itself, but rather in producing **plausible** counterfactuals. Provided that our counterfactual $x^\\prime$ is already inside the target domain ($\\mathbb{I}_{y^\\prime = t}=1$), penalizing $\\Omega(C_{\\theta}(x;\\tau))$ corresponds to guiding counterfactuals into regions of the target domain that are characterized by low ambiguity: for $\\kappa=1$ the conformal prediction set includes only the target label $t$ as $\\Omega(C_{\\theta}(x;\\tau))$. Arguably, less ambiguous counterfactuals are more **plausible**. Since the search is guided purely by properties of the model itself and (exchangeable) calibration data, counterfactuals also maintain **high fidelity**.\n\nThe left panel of @fig-losses shows the smooth size penalty in the two-dimensional feature space of our synthetic data.\n\n### Configurable Classification Loss\n\nThe right panel of @fig-losses shows the configurable classification loss in the two-dimensional feature space of our synthetic data.\n\n::: {.cell execution_count=5}\n\n::: {.cell-output .cell-output-display execution_count=6}\n{#fig-losses}\n:::\n:::\n\n\n## Fidelity and Plausibility\n\nThe main evaluation criteria we are interested in are *fidelity* and *plausibility*. 
Interestingly, we could also consider using these measures as penalties in the counterfactual search.\n\n### Fidelity\n\nWe propose to define fidelity as follows:\n\n::: {#def-fidelity}\n\n## High-Fidelity Counterfactuals\n\nLet $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\\theta$. Then for $x^{\\prime}$ to be considered a high-fidelity counterfactual, we need: $\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}$ where $t$ denotes the target outcome.\n\n:::\n\nWe can generate samples from $p_{\\theta}(X|y)$ following @grathwohl2020your. In @fig-energy, I have applied the methodology to our synthetic data.\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nM = CCE.ConformalModel(conf_model, mach.fitresult)\n\nniter = 100\nnsamples = 100\n\nplts = []\nfor (i,target) ∈ enumerate(counterfactual_data.y_levels)\n sampler = CCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100)\n Xgen = rand(sampler, nsamples)\n plt = Plots.plot(M, counterfactual_data; target=target, zoom=-3,cbar=false)\n Plots.scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=i,shape=:star,label=\"X|y=$target\")\n push!(plts, plt)\nend\nPlots.plot(plts..., layout=(1,length(plts)), size=(img_height*length(plts),img_height))\n```\n\n::: {.cell-output .cell-output-display execution_count=7}\n{#fig-energy}\n:::\n:::\n\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual $x^{\\prime}$ from these generated samples, for example.\n\n### Plausibility\n\nWe propose to define plausibility as follows:\n\n::: {#def-plausible}\n\n## Plausible Counterfactuals\n\nFormally, let $\\mathcal{X}|t$ denote the conditional distribution of samples in the target class. 
As before, we have $x^{\\prime}\\sim\\mathcal{X}^{\\prime}$, then for $x^{\\prime}$ to be considered a plausible counterfactual, we need: $\\mathcal{X}|t \\approxeq \\mathcal{X}^{\\prime}$.\n\n:::\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual $x^{\\prime}$ from (potentially bootstrapped) training samples in the target class, for example.\n\n## Counterfactual Explanations\n\nNext, let's generate counterfactual explanations for our synthetic data. We first wrap our model in a container that makes it compatible with `CounterfactualExplanations.jl`. Then we draw a random sample, determine its predicted label $\\hat{y}$ and choose the opposite label as our target. \n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nx = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))\ny_factual = predict_label(M, counterfactual_data, x)[1]\ntarget = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y_factual][1]\n```\n:::\n\n\nThe generic Conformal Counterfactual Generator penalises the only the set size only:\n\n$$\nx^\\prime = \\arg \\min_{x^\\prime} \\ell(M(x^\\prime),t) + \\lambda \\mathbb{I}_{y^\\prime = t} \\Omega(C_{\\theta}(x;\\tau)) \n$$ {#eq-solution}\n\n::: {.cell execution_count=8}\n\n::: {.cell-output .cell-output-display execution_count=9}\n{#fig-ce}\n:::\n:::\n\n\n## Multi-Class\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\ncounterfactual_data = load_multi_class()\n```\n:::\n\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nX = table(permutedims(counterfactual_data.X))\ny = counterfactual_data.output_encoder.labels\n```\n:::\n\n\n::: {.cell execution_count=11}\n\n::: {.cell-output .cell-output-display execution_count=12}\n{#fig-pen-multi}\n:::\n:::\n\n\n::: {.cell execution_count=12}\n\n::: {.cell-output .cell-output-display execution_count=13}\n{#fig-losses-multi}\n:::\n:::\n\n\n::: {.cell execution_count=13}\n\n::: {.cell-output .cell-output-display 
execution_count=14}\n{#fig-energy-multi}\n:::\n:::\n\n\n::: {.cell execution_count=14}\n``` {.julia .cell-code}\nx = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))\ny_factual = predict_label(M, counterfactual_data, x)[1]\ntarget = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y_factual][1]\n```\n:::\n\n\n::: {.cell execution_count=15}\n\n::: {.cell-output .cell-output-display execution_count=16}\n{#fig-ce-multi}\n:::\n:::\n\n\n## Benchmarks\n\n::: {.cell execution_count=16}\n``` {.julia .cell-code}\n# Data:\ndatasets = Dict(\n :linearly_separable => load_linearly_separable(),\n :overlapping => load_overlapping(),\n :moons => load_moons(),\n :circles => load_circles(),\n :multi_class => load_multi_class(),\n)\n\n# Untrained Models:\nmodels = Dict(\n :cov75 => CCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.75)),\n :cov80 => CCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.80)),\n :cov90 => CCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.90)),\n :cov99 => CCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.99)),\n)\n```\n:::\n\n\nThen we can simply loop over the datasets and eventually concatenate the results like so:\n\n::: {.cell execution_count=17}\n``` {.julia .cell-code}\nusing CounterfactualExplanations.Evaluation: benchmark\nbmks = []\nmeasures = [\n CounterfactualExplanations.distance,\n CCE.distance_from_energy,\n CCE.distance_from_targets\n]\nfor (dataname, dataset) in datasets\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=10\n )\n push!(bmks, bmk)\nend\nbmk = reduce(vcat, bmks)\n```\n:::\n\n\n::: {.cell execution_count=18}\n``` {.julia .cell-code}\nf(ce) = CounterfactualExplanations.model_evaluation(ce.M, ce.data)\n@chain bmk() begin\n @group_by(model, generator, dataname, variable)\n 
@select(model, generator, dataname, ce, value)\n @mutate(performance = f(ce))\n @summarize(model=unique(model), generator=unique(generator), dataname=unique(dataname), performace=unique(performance), value=mean(value))\n @ungroup\n @filter(dataname == :multi_class)\n @filter(model == :cov99)\n @filter(variable == \"distance\")\nend\n```\n:::\n\n\n::: {#fig-benchmark .cell execution_count=19}\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-1}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-2}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-3}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-4}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-5}\n:::\n\nBenchmark results for the different generators.\n:::\n\n\n", + "markdown": "::: {.cell execution_count=1}\n``` {.julia .cell-code}\ninclude(\"notebooks/setup.jl\")\neval(setup_notebooks)\n```\n:::\n\n\n# `ConformalGenerator`\n\nIn this section, we will look at a simple example involving synthetic data, a black-box model and a generic Conformal Counterfactual Generator.\n\n## Black-box Model\n\nWe consider a simple binary classification problem. Let $(X_i, Y_i), \\ i=1,...,n$ denote our feature-label pairs and let $\\mu: \\mathcal{X} \\mapsto \\mathcal{Y}$ denote the mapping from features to labels. For illustration purposes, we will use linearly separable data. \n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\ncounterfactual_data = load_linearly_separable()\n```\n:::\n\n\nWhile we could use a linear classifier in this case, let's pretend we need a black-box model for this task and rely on a small Multi-Layer Perceptron (MLP):\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\nbuilder = MLJFlux.@builder Flux.Chain(\n Dense(n_in, 32, relu),\n Dense(32, n_out)\n)\nclf = NeuralNetworkClassifier(builder=builder, epochs=100)\n```\n:::\n\n\nWe can fit this model to data to produce plug-in predictions. 
\n\n## Conformal Prediction\n\nHere we use a specific case of Conformal Prediction (CP) called *split conformal prediction*, which can be summarized as follows:^[In other places split conformal prediction is sometimes referred to as *inductive* conformal prediction.]\n\n1. Partition the training data into a proper training set and a separate calibration set: $\\mathcal{D}_n=\\mathcal{D}^{\\text{train}} \\cup \\mathcal{D}^{\\text{cali}}$.\n2. Train the machine learning model on the proper training set: $\\hat\\mu_{i \\in \\mathcal{D}^{\\text{train}}}(X_i,Y_i)$.\n\nThe model $\\hat\\mu_{i \\in \\mathcal{D}^{\\text{train}}}$ can now produce plug-in predictions. \n\n::: callout-note\n\n## Starting Point\n\nNote that this represents the starting point in applications of Algorithmic Recourse: we have some pre-trained classifier $M$ for which we would like to generate plausible Counterfactual Explanations. Next, we turn to the calibration step. \n:::\n\n3. Compute nonconformity scores, $\\mathcal{S}$, using the calibration data $\\mathcal{D}^{\\text{cali}}$ and the fitted model $\\hat\\mu_{i \\in \\mathcal{D}^{\\text{train}}}$. \n4. For a user-specified desired coverage ratio $(1-\\alpha)$ compute the corresponding quantile, $\\hat{q}$, of the empirical distribution of nonconformity scores, $\\mathcal{S}$.\n5. For the given quantile and test sample $X_{\\text{test}}$, form the corresponding conformal prediction set: \n\n$$\nC(X_{\\text{test}})=\\{y:s(X_{\\text{test}},y) \\le \\hat{q}\\}\n$$ {#eq-set}\n\nThis is the default procedure used for classification and regression in [`ConformalPrediction.jl`](https://github.com/pat-alt/ConformalPrediction.jl). 
\n\nUsing the package, we can apply Split Conformal Prediction as follows:\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nX = table(permutedims(counterfactual_data.X))\ny = counterfactual_data.output_encoder.labels\nconf_model = conformal_model(clf; method=:simple_inductive)\nmach = machine(conf_model, X, y)\nfit!(mach)\n```\n:::\n\n\nTo be clear, all of the calibration steps (3 to 5) are post hoc, and none of them involves any changes to the model parameters. These are two important characteristics of Split Conformal Prediction (SCP) that make it particularly useful in the context of Algorithmic Recourse. Firstly, the fact that SCP involves post hoc calibration steps that happen after training ensures that we need not place any restrictions on the black-box model itself. This stands in contrast to the approach proposed by @schut2021generating in which they essentially restrict the class of models to Bayesian models. Secondly, the fact that the model itself is kept entirely intact ensures that the generated counterfactuals maintain fidelity to the model. Finally, note that we also have not resorted to a surrogate model to learn more about $X \\sim \\mathcal{X}$. Instead, we have used the fitted model itself and a calibration data set to learn about the model's predictive uncertainty. \n\n## Differentiable CP\n\nIn order to use CP in the context of gradient-based counterfactual search, we need it to be differentiable. @stutz2022learning introduce a framework for training differentiable conformal predictors. They introduce a configurable loss function as well as a smooth set size penalty.\n\n### Smooth Set Size Penalty\n\nStarting with the latter, @stutz2022learning propose the following:\n\n$$\n\\Omega(C_{\\theta}(x;\\tau)) = \\max (0, \\sum_k C_{\\theta,k}(x;\\tau) - \\kappa)\n$$ {#eq-size-loss}\n\nHere, $C_{\\theta,k}(x;\\tau)$ is loosely defined as the probability that class $k$ is assigned to the conformal prediction set $C$. 
In the context of Conformal Training, this penalty reduces the **inefficiency** of the conformal predictor. \n\nIn our context, we are not interested in improving the model itself, but rather in producing **plausible** counterfactuals. Provided that our counterfactual $x^\\prime$ is already inside the target domain ($\\mathbb{I}_{y^\\prime = t}=1$), penalizing $\\Omega(C_{\\theta}(x;\\tau))$ corresponds to guiding counterfactuals into regions of the target domain that are characterized by low ambiguity: for $\\kappa=1$, the conformal prediction set includes only the target label $t$ as $\\Omega(C_{\\theta}(x;\\tau))$ approaches zero. Arguably, less ambiguous counterfactuals are more **plausible**. Since the search is guided purely by properties of the model itself and (exchangeable) calibration data, counterfactuals also maintain **high fidelity**.\n\nThe left panel of @fig-losses shows the smooth size penalty in the two-dimensional feature space of our synthetic data.\n\n### Configurable Classification Loss\n\nThe right panel of @fig-losses shows the configurable classification loss in the two-dimensional feature space of our synthetic data.\n\n::: {.cell execution_count=5}\n\n::: {.cell-output .cell-output-display execution_count=6}\n{#fig-losses}\n:::\n:::\n\n\n## Fidelity and Plausibility\n\nThe main evaluation criteria we are interested in are *fidelity* and *plausibility*. Interestingly, we could also consider using these measures as penalties in the counterfactual search.\n\n### Fidelity\n\nWe propose to define fidelity as follows:\n\n::: {#def-fidelity}\n\n## High-Fidelity Counterfactuals\n\nLet $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\\theta$. Then for $x^{\\prime}$ to be considered a high-fidelity counterfactual, we need: $\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}$ where $t$ denotes the target outcome.\n\n:::\n\nWe can generate samples from $p_{\\theta}(X|y)$ following @grathwohl2020your. 
In @fig-energy, I have applied the methodology to our synthetic data.\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nM = ECCCE.ConformalModel(conf_model, mach.fitresult)\n\nniter = 100\nnsamples = 100\n\nplts = []\nfor (i,target) ∈ enumerate(counterfactual_data.y_levels)\n sampler = ECCCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100)\n Xgen = rand(sampler, nsamples)\n plt = Plots.plot(M, counterfactual_data; target=target, zoom=-3,cbar=false)\n Plots.scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=i,shape=:star,label=\"X|y=$target\")\n push!(plts, plt)\nend\nPlots.plot(plts..., layout=(1,length(plts)), size=(img_height*length(plts),img_height))\n```\n\n::: {.cell-output .cell-output-display execution_count=7}\n{#fig-energy}\n:::\n:::\n\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual $x^{\\prime}$ from these generated samples, for example.\n\n### Plausibility\n\nWe propose to define plausibility as follows:\n\n::: {#def-plausible}\n\n## Plausible Counterfactuals\n\nFormally, let $\\mathcal{X}|t$ denote the conditional distribution of samples in the target class. As before, we have $x^{\\prime}\\sim\\mathcal{X}^{\\prime}$, then for $x^{\\prime}$ to be considered a plausible counterfactual, we need: $\\mathcal{X}|t \\approxeq \\mathcal{X}^{\\prime}$.\n\n:::\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual $x^{\\prime}$ from (potentially bootstrapped) training samples in the target class, for example.\n\n## Counterfactual Explanations\n\nNext, let's generate counterfactual explanations for our synthetic data. We first wrap our model in a container that makes it compatible with `CounterfactualExplanations.jl`. Then we draw a random sample, determine its predicted label $\\hat{y}$ and choose the opposite label as our target. 
\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nx = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))\ny_factual = predict_label(M, counterfactual_data, x)[1]\ntarget = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y_factual][1]\n```\n:::\n\n\nThe generic Conformal Counterfactual Generator penalises the set size only:\n\n$$\nx^\\prime = \\arg \\min_{x^\\prime} \\ell(M(x^\\prime),t) + \\lambda \\mathbb{I}_{y^\\prime = t} \\Omega(C_{\\theta}(x;\\tau)) \n$$ {#eq-solution}\n\n::: {.cell execution_count=8}\n\n::: {.cell-output .cell-output-display execution_count=9}\n{#fig-ce}\n:::\n:::\n\n\n## Multi-Class\n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\ncounterfactual_data = load_multi_class()\n```\n:::\n\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nX = table(permutedims(counterfactual_data.X))\ny = counterfactual_data.output_encoder.labels\n```\n:::\n\n\n::: {.cell execution_count=11}\n\n::: {.cell-output .cell-output-display execution_count=12}\n{#fig-pen-multi}\n:::\n:::\n\n\n::: {.cell execution_count=12}\n\n::: {.cell-output .cell-output-display execution_count=13}\n{#fig-losses-multi}\n:::\n:::\n\n\n::: {.cell execution_count=13}\n\n::: {.cell-output .cell-output-display execution_count=14}\n{#fig-energy-multi}\n:::\n:::\n\n\n::: {.cell execution_count=14}\n``` {.julia .cell-code}\nx = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))\ny_factual = predict_label(M, counterfactual_data, x)[1]\ntarget = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y_factual][1]\n```\n:::\n\n\n::: {.cell execution_count=15}\n\n::: {.cell-output .cell-output-display execution_count=16}\n{#fig-ce-multi}\n:::\n:::\n\n\n## Benchmarks\n\n::: {.cell execution_count=16}\n``` {.julia .cell-code}\n# Data:\ndatasets = Dict(\n :linearly_separable => load_linearly_separable(),\n :overlapping => load_overlapping(),\n :moons => load_moons(),\n :circles => load_circles(),\n 
:multi_class => load_multi_class(),\n)\n\n# Untrained Models:\nmodels = Dict(\n :cov75 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.75)),\n :cov80 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.80)),\n :cov90 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.90)),\n :cov99 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.99)),\n)\n```\n:::\n\n\nThen we can simply loop over the datasets and eventually concatenate the results like so:\n\n::: {.cell execution_count=17}\n``` {.julia .cell-code}\nusing CounterfactualExplanations.Evaluation: benchmark\nbmks = []\nmeasures = [\n CounterfactualExplanations.distance,\n ECCCE.distance_from_energy,\n ECCCE.distance_from_targets\n]\nfor (dataname, dataset) in datasets\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=10\n )\n push!(bmks, bmk)\nend\nbmk = reduce(vcat, bmks)\n```\n:::\n\n\n::: {.cell execution_count=18}\n``` {.julia .cell-code}\nf(ce) = CounterfactualExplanations.model_evaluation(ce.M, ce.data)\n@chain bmk() begin\n @group_by(model, generator, dataname, variable)\n @select(model, generator, dataname, ce, value)\n @mutate(performance = f(ce))\n @summarize(model=unique(model), generator=unique(generator), dataname=unique(dataname), performace=unique(performance), value=mean(value))\n @ungroup\n @filter(dataname == :multi_class)\n @filter(model == :cov99)\n @filter(variable == \"distance\")\nend\n```\n:::\n\n\n::: {#fig-benchmark .cell execution_count=19}\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-1}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-2}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-3}\n:::\n\n::: {.cell-output .cell-output-display}\n{#fig-benchmark-4}\n:::\n\n::: {.cell-output 
.cell-output-display}\n{#fig-benchmark-5}\n:::\n\nBenchmark results for the different generators.\n:::\n\n\n", "supporting": [ "intro_files/figure-html" ], diff --git a/_freeze/notebooks/proposal/execute-results/html.json b/_freeze/notebooks/proposal/execute-results/html.json index 433db5ebe76c7a9899d699b41788475044666c98..e9804de6ee50ab04854e7bc338b7099606456d67 100644 --- a/_freeze/notebooks/proposal/execute-results/html.json +++ b/_freeze/notebooks/proposal/execute-results/html.json @@ -1,7 +1,7 @@ { "hash": "24ab407f04257b00a84f7dcaee456281", "result": { - "markdown": "---\ntitle: High-Fidelity Counterfactual Explanations through Conformal Prediction\nsubtitle: Research Proposal\nabstract: |\n We propose Conformal Counterfactual Explanations: an effortless and rigorous way to produce realistic and faithful Counterfactual Explanations using Conformal Prediction. To address the need for realistic counterfactuals, existing work has primarily relied on separate generative models to learn the data-generating process. While this is an effective way to produce plausible and model-agnostic counterfactual explanations, it not only introduces a significant engineering overhead but also reallocates the task of creating realistic model explanations from the model itself to the generative model. Recent work has shown that there is no need for any of this when working with probabilistic models that explicitly quantify their own uncertainty. Unfortunately, most models used in practice still do not fulfil that basic requirement, in which case we would like to have a way to quantify predictive uncertainty in a post-hoc fashion.\n---\n\n\n\n## Motivation\n\nCounterfactual Explanations are a powerful, flexible and intuitive way to not only explain black-box models but also enable affected individuals to challenge them through the means of Algorithmic Recourse. 
\n\n### Counterfactual Explanations or Adversarial Examples?\n\nMost state-of-the-art approaches to generating Counterfactual Explanations (CE) rely on gradient descent in the feature space. The key idea is to perturb inputs $x\\in\\mathcal{X}$ into a black-box model $f: \\mathcal{X} \\mapsto \\mathcal{Y}$ in order to change the model output $f(x)$ to some pre-specified target value $t\\in\\mathcal{Y}$. Formally, this boils down to defining some loss function $\\ell(f(x),t)$ and taking gradient steps in the minimizing direction. The so-generated counterfactuals are considered valid as soon as the predicted label matches the target label. A stripped-down counterfactual explanation is therefore little different from an adversarial example. In @fig-adv, for example, generic counterfactual search as in @wachter2017counterfactual has been applied to MNIST data.\n\n\n\n\n\n\n\n{#fig-adv}\n\nThe crucial difference between adversarial examples and counterfactuals is one of intent. While adversarial examples are typically intended to go unnoticed, counterfactuals in the context of Explainable AI are generally sought to be \"plausible\", \"realistic\" or \"feasible\". To fulfil this latter goal, researchers have come up with a myriad of ways. @joshi2019realistic were among the first to suggest that instead of searching counterfactuals in the feature space, we can instead traverse a latent embedding learned by a surrogate generative model. Similarly, @poyiadzi2020face use density ... Finally, @karimi2021algorithmic argues that counterfactuals should comply with the causal model that generates them [CHECK IF WE CAN PHASE THIS LIKE THIS]. Other related approaches include ... All of these different approaches have a common goal: they aim to ensure that the generated counterfactuals comply with the (learned) data-generating process (DGB). 
\n\n::: {#def-plausible}\n\n## Plausible Counterfactuals\n\nFormally, if $x \\sim \\mathcal{X}$ and for the corresponding counterfactual we have $x^{\\prime}\\sim\\mathcal{X}^{\\prime}$, then for $x^{\\prime}$ to be considered a plausible counterfactual, we need: $\\mathcal{X} \\approxeq \\mathcal{X}^{\\prime}$.\n\n:::\n\nIn the context of Algorithmic Recourse, it makes sense to strive for plausible counterfactuals, since anything else would essentially require individuals to move to out-of-distribution states. But it is worth noting that our ambition to meet this goal, may have implications on our ability to faithfully explain the behaviour of the underlying black-box model (arguably our principal goal). By essentially decoupling the task of learning plausible representations of the data from the model itself, we open ourselves up to vulnerabilities. Using a separate generative model to learn $\\mathcal{X}$, for example, has very serious implications for the generated counterfactuals. @fig-latent compares the results of applying REVISE [@joshi2019realistic] to MNIST data using two different Variational Auto-Encoders: while the counterfactual generated using an expressive (strong) VAE is compelling, the result relying on a less expressive (weak) VAE is not even valid. In this latter case, the decoder step of the VAE fails to yield values in $\\mathcal{X}$ and hence the counterfactual search in the learned latent space is doomed. \n\n\n\n\n\n\n\n{#fig-latent}\n\n> Here it would be nice to have another example where we poison the data going into the generative model to hide biases present in the data (e.g. 
Boston housing).\n\n- Latent can be manipulated: \n - train biased model\n - train VAE with biased variable removed/attacked (use Boston housing dataset)\n - hypothesis: will generate bias-free explanations\n\n### From Plausible to High-Fidelity Counterfactuals {#sec-fidelity}\n\nIn light of the findings, we propose to generally avoid using surrogate models to learn $\\mathcal{X}$ in the context of Counterfactual Explanations.\n\n::: {#prp-surrogate}\n\n## Avoid Surrogates\n\nSince we are in the business of explaining a black-box model, the task of learning realistic representations of the data should not be reallocated from the model itself to some surrogate model.\n\n:::\n\nIn cases where the use of surrogate models cannot be avoided, we propose to weigh the plausibility of counterfactuals against their fidelity to the black-box model. In the context of Explainable AI, fidelity is defined as describing how an explanation approximates the prediction of the black-box model [@molnar2020interpretable]. Fidelity has become the default metric for evaluating Local Model-Agnostic Models, since they often involve local surrogate models whose predictions need not always match those of the black-box model. \n\nIn the case of Counterfactual Explanations, the concept of fidelity has so far been ignored. This is not altogether surprising, since by construction and design, Counterfactual Explanations work with the predictions of the black-box model directly: as stated above, a counterfactual $x^{\\prime}$ is considered valid if and only if $f(x^{\\prime})=t$, where $t$ denote some target outcome. \n\nDoes fidelity even make sense in the context of CE, and if so, how can we define it? In light of the examples in the previous section, we think it is urgent to introduce a notion of fidelity in this context, that relates to the distributional properties of the generated counterfactuals. 
In particular, we propose that a high-fidelity counterfactual $x^{\\prime}$ complies with the class-conditional distribution $\\mathcal{X}_{\\theta} = p_{\\theta}(X|y)$ where $\\theta$ denote the black-box model parameters. \n\n::: {#def-fidele}\n\n## High-Fidelity Counterfactuals\n\nLet $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\\theta$. Then for $x^{\\prime}$ to be considered a high-fidelity counterfactual, we need: $\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}$ where $t$ denotes the target outcome.\n\n:::\n\nIn order to assess the fidelity of counterfactuals, we propose the following two-step procedure:\n\n1) Generate samples $X_{\\theta}|y$ and $X^{\\prime}$ from $\\mathcal{X}_{\\theta}|t$ and $\\mathcal{X}^{\\prime}$, respectively.\n2) Compute the Maximum Mean Discrepancy (MMD) between $X_{\\theta}|y$ and $X^{\\prime}$. \n\nIf the computed value is different from zero, we can reject the null-hypothesis of fidelity.\n\n> Two challenges here: 1) implementing the sampling procedure in @grathwohl2020your; 2) it is unclear if MMD is really the right way to measure this. \n\n## Conformal Counterfactual Explanations\n\nIn @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (CCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models. \n\n### Minimizing Predictive Uncertainty\n\n@schut2021generating demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. 
Similarly, @antoran2020getting ...\n\n- Problem: restricted to Bayesian models.\n- Solution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction. \n\n### Background on Conformal Prediction\n\n- Distribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.\n- Conformal prediction is instance-based. So is CE. \n- Take any fitted model and turn it into a conformal model using calibration data.\n- Our approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway. \n\n> Does the coverage guarantee carry over to counterfactuals?\n\n### Generating Conformal Counterfactuals\n\nWhile Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, @stutz2022learning introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search ...\n\n> Challenge: still need to implement these loss functions. 
\n\n## Experiments\n\n### Research Questions\n\n- Is CP alone enough to ensure realistic counterfactuals?\n- Do counterfactuals improve further as the models get better?\n- Do counterfactuals get more realistic as coverage\n- What happens as we vary coverage and setsize?\n- What happens as we improve the model robustness?\n- What happens as we improve the model's ability to incorporate predictive uncertainty (deep ensemble, laplace)?\n- What happens if we combine with DiCE, ClaPROAR, Gravitational?\n- What about CE robustness to endogenous shifts [@altmeyer2023endogenous]?\n\n- Benchmarking:\n - add PROBE [@pawelczyk2022probabilistically] into the mix.\n - compare travel costs to domain shits.\n\n> Nice to have: What about using Laplace Approximation, then Conformal Prediction? What about using Conformalised Laplace? \n\n## References\n\n", + "markdown": "---\ntitle: High-Fidelity Counterfactual Explanations through Conformal Prediction\nsubtitle: Research Proposal\nabstract: |\n We propose Conformal Counterfactual Explanations: an effortless and rigorous way to produce realistic and faithful Counterfactual Explanations using Conformal Prediction. To address the need for realistic counterfactuals, existing work has primarily relied on separate generative models to learn the data-generating process. While this is an effective way to produce plausible and model-agnostic counterfactual explanations, it not only introduces a significant engineering overhead but also reallocates the task of creating realistic model explanations from the model itself to the generative model. Recent work has shown that there is no need for any of this when working with probabilistic models that explicitly quantify their own uncertainty. 
Unfortunately, most models used in practice still do not fulfil that basic requirement, in which case we would like to have a way to quantify predictive uncertainty in a post-hoc fashion.\n---\n\n\n\n## Motivation\n\nCounterfactual Explanations are a powerful, flexible and intuitive way to not only explain black-box models but also enable affected individuals to challenge them through the means of Algorithmic Recourse. \n\n### Counterfactual Explanations or Adversarial Examples?\n\nMost state-of-the-art approaches to generating Counterfactual Explanations (CE) rely on gradient descent in the feature space. The key idea is to perturb inputs $x\\in\\mathcal{X}$ into a black-box model $f: \\mathcal{X} \\mapsto \\mathcal{Y}$ in order to change the model output $f(x)$ to some pre-specified target value $t\\in\\mathcal{Y}$. Formally, this boils down to defining some loss function $\\ell(f(x),t)$ and taking gradient steps in the minimizing direction. The so-generated counterfactuals are considered valid as soon as the predicted label matches the target label. A stripped-down counterfactual explanation is therefore little different from an adversarial example. In @fig-adv, for example, generic counterfactual search as in @wachter2017counterfactual has been applied to MNIST data.\n\n\n\n\n\n\n\n{#fig-adv}\n\nThe crucial difference between adversarial examples and counterfactuals is one of intent. While adversarial examples are typically intended to go unnoticed, counterfactuals in the context of Explainable AI are generally sought to be \"plausible\", \"realistic\" or \"feasible\". To fulfil this latter goal, researchers have come up with a myriad of ways. @joshi2019realistic were among the first to suggest that instead of searching counterfactuals in the feature space, we can instead traverse a latent embedding learned by a surrogate generative model. Similarly, @poyiadzi2020face use density ... 
Finally, @karimi2021algorithmic argue that counterfactuals should comply with the causal model that generates them [CHECK IF WE CAN PHRASE THIS LIKE THIS]. Other related approaches include ... All of these different approaches have a common goal: they aim to ensure that the generated counterfactuals comply with the (learned) data-generating process (DGP). \n\n::: {#def-plausible}\n\n## Plausible Counterfactuals\n\nFormally, if $x \\sim \\mathcal{X}$ and for the corresponding counterfactual we have $x^{\\prime}\\sim\\mathcal{X}^{\\prime}$, then for $x^{\\prime}$ to be considered a plausible counterfactual, we need: $\\mathcal{X} \\approxeq \\mathcal{X}^{\\prime}$.\n\n:::\n\nIn the context of Algorithmic Recourse, it makes sense to strive for plausible counterfactuals, since anything else would essentially require individuals to move to out-of-distribution states. But it is worth noting that our ambition to meet this goal may have implications for our ability to faithfully explain the behaviour of the underlying black-box model (arguably our principal goal). By essentially decoupling the task of learning plausible representations of the data from the model itself, we open ourselves up to vulnerabilities. Using a separate generative model to learn $\\mathcal{X}$, for example, has very serious implications for the generated counterfactuals. @fig-latent compares the results of applying REVISE [@joshi2019realistic] to MNIST data using two different Variational Auto-Encoders: while the counterfactual generated using an expressive (strong) VAE is compelling, the result relying on a less expressive (weak) VAE is not even valid. In this latter case, the decoder step of the VAE fails to yield values in $\\mathcal{X}$ and hence the counterfactual search in the learned latent space is doomed. \n\n\n\n\n\n\n\n{#fig-latent}\n\n> Here it would be nice to have another example where we poison the data going into the generative model to hide biases present in the data (e.g. 
Boston housing).\n\n- Latent can be manipulated: \n - train biased model\n - train VAE with biased variable removed/attacked (use Boston housing dataset)\n - hypothesis: will generate bias-free explanations\n\n### From Plausible to High-Fidelity Counterfactuals {#sec-fidelity}\n\nIn light of the findings, we propose to generally avoid using surrogate models to learn $\\mathcal{X}$ in the context of Counterfactual Explanations.\n\n::: {#prp-surrogate}\n\n## Avoid Surrogates\n\nSince we are in the business of explaining a black-box model, the task of learning realistic representations of the data should not be reallocated from the model itself to some surrogate model.\n\n:::\n\nIn cases where the use of surrogate models cannot be avoided, we propose to weigh the plausibility of counterfactuals against their fidelity to the black-box model. In the context of Explainable AI, fidelity is defined as describing how an explanation approximates the prediction of the black-box model [@molnar2020interpretable]. Fidelity has become the default metric for evaluating Local Model-Agnostic Models, since they often involve local surrogate models whose predictions need not always match those of the black-box model. \n\nIn the case of Counterfactual Explanations, the concept of fidelity has so far been ignored. This is not altogether surprising, since by construction and design, Counterfactual Explanations work with the predictions of the black-box model directly: as stated above, a counterfactual $x^{\\prime}$ is considered valid if and only if $f(x^{\\prime})=t$, where $t$ denotes some target outcome. \n\nDoes fidelity even make sense in the context of CE, and if so, how can we define it? In light of the examples in the previous section, we think it is urgent to introduce a notion of fidelity in this context that relates to the distributional properties of the generated counterfactuals. 
In particular, we propose that a high-fidelity counterfactual $x^{\\prime}$ complies with the class-conditional distribution $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ where $\\theta$ denotes the black-box model parameters. \n\n::: {#def-fidele}\n\n## High-Fidelity Counterfactuals\n\nLet $\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\\theta$. Then for $x^{\\prime}$ to be considered a high-fidelity counterfactual, we need: $\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}$ where $t$ denotes the target outcome.\n\n:::\n\nIn order to assess the fidelity of counterfactuals, we propose the following two-step procedure:\n\n1) Generate samples $X_{\\theta}|t$ and $X^{\\prime}$ from $\\mathcal{X}_{\\theta}|t$ and $\\mathcal{X}^{\\prime}$, respectively.\n2) Compute the Maximum Mean Discrepancy (MMD) between $X_{\\theta}|t$ and $X^{\\prime}$. \n\nIf the computed value is significantly different from zero, we can reject the null hypothesis of fidelity.\n\n> Two challenges here: 1) implementing the sampling procedure in @grathwohl2020your; 2) it is unclear if MMD is really the right way to measure this. \n\n## Conformal Counterfactual Explanations\n\nIn @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models. \n\n### Minimizing Predictive Uncertainty\n\n@schut2021generating demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. 
Similarly, @antoran2020getting ...\n\n- Problem: restricted to Bayesian models.\n- Solution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction. \n\n### Background on Conformal Prediction\n\n- Distribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.\n- Conformal prediction is instance-based. So is CE. \n- Take any fitted model and turn it into a conformal model using calibration data.\n- Our approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway. \n\n> Does the coverage guarantee carry over to counterfactuals?\n\n### Generating Conformal Counterfactuals\n\nWhile Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, @stutz2022learning introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search ...\n\n> Challenge: still need to implement these loss functions. 
\n\n## Experiments\n\n### Research Questions\n\n- Is CP alone enough to ensure realistic counterfactuals?\n- Do counterfactuals improve further as the models get better?\n- Do counterfactuals get more realistic as coverage\n- What happens as we vary coverage and set size?\n- What happens as we improve the model robustness?\n- What happens as we improve the model's ability to incorporate predictive uncertainty (deep ensemble, Laplace)?\n- What happens if we combine with DiCE, ClaPROAR, Gravitational?\n- What about CE robustness to endogenous shifts [@altmeyer2023endogenous]?\n\n- Benchmarking:\n - add PROBE [@pawelczyk2022probabilistically] into the mix.\n - compare travel costs to domain shifts.\n\n> Nice to have: What about using Laplace Approximation, then Conformal Prediction? What about using Conformalised Laplace? \n\n## References\n\n", "supporting": [ "proposal_files/figure-html" ], diff --git a/_freeze/notebooks/synthetic/execute-results/html.json b/_freeze/notebooks/synthetic/execute-results/html.json index 5770692cb22ebd32f2f297ff262b3853e79676f2..682d0f96f0cf96570eecbd0646392516bcc73e08 100644 --- a/_freeze/notebooks/synthetic/execute-results/html.json +++ b/_freeze/notebooks/synthetic/execute-results/html.json @@ -1,7 +1,7 @@ { "hash": "617bb13e20ec081d43c585fd80675156", "result": { - "markdown": "::: {.cell execution_count=1}\n``` {.julia .cell-code}\ninclude(\"notebooks/setup.jl\")\neval(setup_notebooks);\n```\n:::\n\n\n# Synthetic data\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n# Data:\ndatasets = Dict(\n :linearly_separable => load_linearly_separable(),\n :overlapping => load_overlapping(),\n :moons => load_moons(),\n :circles => load_circles(),\n :multi_class => load_multi_class(),\n)\n\n# Hyperparameters:\ncvgs = [0.5, 0.75, 0.95]\ntemps = [0.01, 0.1, 1.0]\nΛ = [0.0, 0.1, 1.0, 10.0]\nl2_λ = 0.1\n\n# Classifiers:\nepochs = 250\nlink_fun = relu\nlogreg = NeuralNetworkClassifier(builder=MLJFlux.Linear(σ=link_fun), epochs=epochs)\nmlp = 
NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=link_fun), epochs=epochs)\nensmbl = EnsembleModel(model=mlp, n=5)\nclassifiers = Dict(\n # :logreg => logreg,\n :mlp => mlp,\n # :ensmbl => ensmbl,\n)\n\n# Search parameters:\ntarget = 2\nfactual = 1\nmax_iter = 50\ngradient_tol = 1e-2\nopt = Descent(0.01)\n```\n:::\n\n\n\n\n\n\n::: {.cell execution_count=5}\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n:::\n\n\n## Benchmark\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\n# Benchmark generators:\ngenerators = Dict(\n :wachter => GenericGenerator(opt=opt, λ=l2_λ),\n :revise => REVISEGenerator(opt=opt, λ=l2_λ),\n :greedy => GreedyGenerator(),\n)\n\n# Untrained Models:\nmodels = Dict(Symbol(\"cov$(Int(100*cov))\") => CCE.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs)\n\n# Measures:\nmeasures = [\n CounterfactualExplanations.distance,\n CCE.distance_from_energy,\n CCE.distance_from_targets,\n 
CounterfactualExplanations.validity,\n]\n```\n:::\n\n\n### Single CE\n\n\n\n\n\n::: {.cell execution_count=9}\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n:::\n\n\n### Full Benchmark\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nbmks = []\nfor (dataname, dataset) in datasets\n for λ in Λ, temp in temps\n _generators = deepcopy(generators)\n _generators[:cce] = CCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt)\n _generators[:energy] = CCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n _generators[:target] = CCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=_generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=5,\n initialization=:identity,\n )\n bmk.evaluation.λ .= λ\n bmk.evaluation.temperature .= temp\n push!(bmks, bmk)\n end\nend\nbmk = reduce(vcat, bmks)\n```\n:::\n\n\n::: {.cell execution_count=11}\n``` {.julia .cell-code}\nCSV.write(joinpath(output_path, \"synthetic_benchmark.csv\"), bmk())\n```\n:::\n\n\n::: {.cell execution_count=12}\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output 
.cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output 
.cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n:::\n\n\n", + "markdown": "::: {.cell execution_count=1}\n``` {.julia .cell-code}\ninclude(\"notebooks/setup.jl\")\neval(setup_notebooks);\n```\n:::\n\n\n# Synthetic data\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n# Data:\ndatasets = Dict(\n :linearly_separable => load_linearly_separable(),\n :overlapping => load_overlapping(),\n :moons => load_moons(),\n :circles => load_circles(),\n :multi_class => load_multi_class(),\n)\n\n# Hyperparameters:\ncvgs = [0.5, 0.75, 0.95]\ntemps = [0.01, 0.1, 1.0]\nΛ = [0.0, 0.1, 1.0, 10.0]\nl2_λ = 0.1\n\n# Classifiers:\nepochs = 250\nlink_fun = relu\nlogreg = NeuralNetworkClassifier(builder=MLJFlux.Linear(σ=link_fun), epochs=epochs)\nmlp = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=link_fun), epochs=epochs)\nensmbl = EnsembleModel(model=mlp, n=5)\nclassifiers = Dict(\n # :logreg => logreg,\n :mlp => mlp,\n # :ensmbl => ensmbl,\n)\n\n# Search parameters:\ntarget = 2\nfactual = 1\nmax_iter = 50\ngradient_tol = 1e-2\nopt = 
Descent(0.01)\n```\n:::\n\n\n\n\n\n\n::: {.cell execution_count=5}\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n:::\n\n\n## Benchmark\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\n# Benchmark generators:\ngenerators = Dict(\n :wachter => GenericGenerator(opt=opt, λ=l2_λ),\n :revise => REVISEGenerator(opt=opt, λ=l2_λ),\n :greedy => GreedyGenerator(),\n)\n\n# Untrained Models:\nmodels = Dict(Symbol(\"cov$(Int(100*cov))\") => ECCCE.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs)\n\n# Measures:\nmeasures = [\n CounterfactualExplanations.distance,\n ECCCE.distance_from_energy,\n ECCCE.distance_from_targets,\n CounterfactualExplanations.validity,\n]\n```\n:::\n\n\n### Single CE\n\n\n\n\n\n::: {.cell execution_count=9}\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: 
{.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n:::\n\n\n### Full Benchmark\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\nbmks = []\nfor (dataname, dataset) in datasets\n for λ in Λ, temp in temps\n _generators = deepcopy(generators)\n _generators[:cce] = ECCCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt)\n _generators[:energy] = ECCCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n _generators[:target] = ECCCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=_generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=5,\n initialization=:identity,\n )\n bmk.evaluation.λ .= λ\n bmk.evaluation.temperature .= temp\n push!(bmks, bmk)\n end\nend\nbmk = reduce(vcat, bmks)\n```\n:::\n\n\n::: {.cell execution_count=11}\n``` {.julia .cell-code}\nCSV.write(joinpath(output_path, \"synthetic_benchmark.csv\"), bmk())\n```\n:::\n\n\n::: {.cell execution_count=12}\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output 
.cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output 
.cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n\n::: {.cell-output .cell-output-display}\n{}\n:::\n:::\n\n\n", "supporting": [ "synthetic_files" ], diff --git a/docs/notebooks/intro.html b/docs/notebooks/intro.html index 6ecdadd7580a202be3498dab882f53db17b9eb69..6a00699b32a2ebbece143893998f4e1ecb2e9aa5 100644 --- a/docs/notebooks/intro.html +++ b/docs/notebooks/intro.html @@ -351,14 +351,14 @@ C(X_{\text{test}})=\{y:s(X_{\text{test}},y) \le \hat{q}\} </div> <p>We can generate samples from <span class="math inline">\(p_{\theta}(X|y)\)</span> following <span class="citation" data-cites="grathwohl2020your">Grathwohl et al. (<a href="references.html#ref-grathwohl2020your" role="doc-biblioref">2020</a>)</span>. 
In <a href="#fig-energy">Figure <span>2.2</span></a>, I have applied the methodology to our synthetic data.</p> <div class="cell" data-execution_count="6"> -<div class="sourceCode cell-code" id="cb5"><pre class="sourceCode julia code-with-copy"><code class="sourceCode julia"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a>M <span class="op">=</span> CCE.<span class="fu">ConformalModel</span>(conf_model, mach.fitresult)</span> +<div class="sourceCode cell-code" id="cb5"><pre class="sourceCode julia code-with-copy"><code class="sourceCode julia"><span id="cb5-1"><a href="#cb5-1" aria-hidden="true" tabindex="-1"></a>M <span class="op">=</span> ECCCE.<span class="fu">ConformalModel</span>(conf_model, mach.fitresult)</span> <span id="cb5-2"><a href="#cb5-2" aria-hidden="true" tabindex="-1"></a></span> <span id="cb5-3"><a href="#cb5-3" aria-hidden="true" tabindex="-1"></a>niter <span class="op">=</span> <span class="fl">100</span></span> <span id="cb5-4"><a href="#cb5-4" aria-hidden="true" tabindex="-1"></a>nsamples <span class="op">=</span> <span class="fl">100</span></span> <span id="cb5-5"><a href="#cb5-5" aria-hidden="true" tabindex="-1"></a></span> <span id="cb5-6"><a href="#cb5-6" aria-hidden="true" tabindex="-1"></a>plts <span class="op">=</span> []</span> <span id="cb5-7"><a href="#cb5-7" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> (i,target) <span class="op">∈</span> <span class="fu">enumerate</span>(counterfactual_data.y_levels)</span> -<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a> sampler <span class="op">=</span> CCE.<span class="fu">EnergySampler</span>(M, counterfactual_data, target; niter<span class="op">=</span>niter, nsamples<span class="op">=</span><span class="fl">100</span>)</span> +<span id="cb5-8"><a href="#cb5-8" aria-hidden="true" tabindex="-1"></a> sampler <span class="op">=</span> ECCCE.<span class="fu">EnergySampler</span>(M, counterfactual_data, target; niter<span 
class="op">=</span>niter, nsamples<span class="op">=</span><span class="fl">100</span>)</span> <span id="cb5-9"><a href="#cb5-9" aria-hidden="true" tabindex="-1"></a> Xgen <span class="op">=</span> <span class="fu">rand</span>(sampler, nsamples)</span> <span id="cb5-10"><a href="#cb5-10" aria-hidden="true" tabindex="-1"></a> plt <span class="op">=</span> Plots.<span class="fu">plot</span>(M, counterfactual_data; target<span class="op">=</span>target, zoom<span class="op">=-</span><span class="fl">3</span>,cbar<span class="op">=</span><span class="cn">false</span>)</span> <span id="cb5-11"><a href="#cb5-11" aria-hidden="true" tabindex="-1"></a> Plots.<span class="fu">scatter!</span>(Xgen[<span class="fl">1</span>,<span class="op">:</span>],Xgen[<span class="fl">2</span>,<span class="op">:</span>],alpha<span class="op">=</span><span class="fl">0.5</span>,color<span class="op">=</span>i,shape<span class="op">=:</span>star,label<span class="op">=</span><span class="st">"X|y=</span><span class="sc">$</span>target<span class="st">"</span>)</span> @@ -477,10 +477,10 @@ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) + \lambda \mathbb{I}_{y^\pr <span id="cb10-9"><a href="#cb10-9" aria-hidden="true" tabindex="-1"></a></span> <span id="cb10-10"><a href="#cb10-10" aria-hidden="true" tabindex="-1"></a><span class="co"># Untrained Models:</span></span> <span id="cb10-11"><a href="#cb10-11" aria-hidden="true" tabindex="-1"></a>models <span class="op">=</span> <span class="fu">Dict</span>(</span> -<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov75 <span class="op">=></span> CCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.75</span>)),</span> -<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov80 <span class="op">=></span> 
CCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.80</span>)),</span> -<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov90 <span class="op">=></span> CCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.90</span>)),</span> -<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov99 <span class="op">=></span> CCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.99</span>)),</span> +<span id="cb10-12"><a href="#cb10-12" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov75 <span class="op">=></span> ECCCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.75</span>)),</span> +<span id="cb10-13"><a href="#cb10-13" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov80 <span class="op">=></span> ECCCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.80</span>)),</span> +<span id="cb10-14"><a href="#cb10-14" aria-hidden="true" tabindex="-1"></a> <span class="op">:</span>cov90 <span class="op">=></span> ECCCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.90</span>)),</span> +<span id="cb10-15"><a href="#cb10-15" aria-hidden="true" 
tabindex="-1"></a> <span class="op">:</span>cov99 <span class="op">=></span> ECCCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(clf; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span><span class="fl">0.99</span>)),</span> <span id="cb10-16"><a href="#cb10-16" aria-hidden="true" tabindex="-1"></a>)</span></code><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></pre></div> </div> <p>Then we can simply loop over the datasets and eventually concatenate the results like so:</p> @@ -489,8 +489,8 @@ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) + \lambda \mathbb{I}_{y^\pr <span id="cb11-2"><a href="#cb11-2" aria-hidden="true" tabindex="-1"></a>bmks <span class="op">=</span> []</span> <span id="cb11-3"><a href="#cb11-3" aria-hidden="true" tabindex="-1"></a>measures <span class="op">=</span> [</span> <span id="cb11-4"><a href="#cb11-4" aria-hidden="true" tabindex="-1"></a> CounterfactualExplanations.distance,</span> -<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a> CCE.distance_from_energy,</span> -<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a> CCE.distance_from_targets</span> +<span id="cb11-5"><a href="#cb11-5" aria-hidden="true" tabindex="-1"></a> ECCCE.distance_from_energy,</span> +<span id="cb11-6"><a href="#cb11-6" aria-hidden="true" tabindex="-1"></a> ECCCE.distance_from_targets</span> <span id="cb11-7"><a href="#cb11-7" aria-hidden="true" tabindex="-1"></a>]</span> <span id="cb11-8"><a href="#cb11-8" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> (dataname, dataset) <span class="kw">in</span> datasets</span> <span id="cb11-9"><a href="#cb11-9" aria-hidden="true" tabindex="-1"></a> bmk <span class="op">=</span> <span class="fu">benchmark</span>(</span> diff --git a/docs/notebooks/proposal.html b/docs/notebooks/proposal.html index 
958533e38ee06a593cad80431ab4b942cafc39af..4d65db90bd6789a0b1962f333355b8d6838870cc 100644 --- a/docs/notebooks/proposal.html +++ b/docs/notebooks/proposal.html @@ -253,7 +253,7 @@ div.csl-indent { </section> <section id="conformal-counterfactual-explanations" class="level2" data-number="1.2"> <h2 data-number="1.2" class="anchored" data-anchor-id="conformal-counterfactual-explanations"><span class="header-section-number">1.2</span> Conformal Counterfactual Explanations</h2> -<p>In <a href="#sec-fidelity"><span>Section 1.1.2</span></a>, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (CCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models.</p> +<p>In <a href="#sec-fidelity"><span>Section 1.1.2</span></a>, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models.</p> <section id="minimizing-predictive-uncertainty" class="level3" data-number="1.2.1"> <h3 data-number="1.2.1" class="anchored" data-anchor-id="minimizing-predictive-uncertainty"><span class="header-section-number">1.2.1</span> Minimizing Predictive Uncertainty</h3> <p><span class="citation" data-cites="schut2021generating">Schut et al. (<a href="references.html#ref-schut2021generating" role="doc-biblioref">2021</a>)</span> demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. 
Similarly, <span class="citation" data-cites="antoran2020getting">Antorán et al. (<a href="references.html#ref-antoran2020getting" role="doc-biblioref">2020</a>)</span> …</p> diff --git a/docs/notebooks/synthetic.html b/docs/notebooks/synthetic.html index db488979c89dad5e252b77f4985e9a60b28caaae..ec5ffcc29748f3be384d57a126c744d09b357a8b 100644 --- a/docs/notebooks/synthetic.html +++ b/docs/notebooks/synthetic.html @@ -330,13 +330,13 @@ code span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } /* Warni <span id="cb3-6"><a href="#cb3-6" aria-hidden="true" tabindex="-1"></a>)</span> <span id="cb3-7"><a href="#cb3-7" aria-hidden="true" tabindex="-1"></a></span> <span id="cb3-8"><a href="#cb3-8" aria-hidden="true" tabindex="-1"></a><span class="co"># Untrained Models:</span></span> -<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>models <span class="op">=</span> <span class="fu">Dict</span>(<span class="fu">Symbol</span>(<span class="st">"cov</span><span class="sc">$</span>(<span class="fu">Int</span>(<span class="fl">100</span><span class="op">*</span>cov))<span class="st">"</span>) <span class="op">=></span> CCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(mlp; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span>cov)) <span class="cf">for</span> cov <span class="kw">in</span> cvgs)</span> +<span id="cb3-9"><a href="#cb3-9" aria-hidden="true" tabindex="-1"></a>models <span class="op">=</span> <span class="fu">Dict</span>(<span class="fu">Symbol</span>(<span class="st">"cov</span><span class="sc">$</span>(<span class="fu">Int</span>(<span class="fl">100</span><span class="op">*</span>cov))<span class="st">"</span>) <span class="op">=></span> ECCCE.<span class="fu">ConformalModel</span>(<span class="fu">conformal_model</span>(mlp; method<span class="op">=:</span>simple_inductive, coverage<span class="op">=</span>cov)) <span class="cf">for</span> cov <span 
class="kw">in</span> cvgs)</span> <span id="cb3-10"><a href="#cb3-10" aria-hidden="true" tabindex="-1"></a></span> <span id="cb3-11"><a href="#cb3-11" aria-hidden="true" tabindex="-1"></a><span class="co"># Measures:</span></span> <span id="cb3-12"><a href="#cb3-12" aria-hidden="true" tabindex="-1"></a>measures <span class="op">=</span> [</span> <span id="cb3-13"><a href="#cb3-13" aria-hidden="true" tabindex="-1"></a> CounterfactualExplanations.distance,</span> -<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a> CCE.distance_from_energy,</span> -<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a> CCE.distance_from_targets,</span> +<span id="cb3-14"><a href="#cb3-14" aria-hidden="true" tabindex="-1"></a> ECCCE.distance_from_energy,</span> +<span id="cb3-15"><a href="#cb3-15" aria-hidden="true" tabindex="-1"></a> ECCCE.distance_from_targets,</span> <span id="cb3-16"><a href="#cb3-16" aria-hidden="true" tabindex="-1"></a> CounterfactualExplanations.validity,</span> <span id="cb3-17"><a href="#cb3-17" aria-hidden="true" tabindex="-1"></a>]</span></code><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></pre></div> </div> @@ -397,9 +397,9 @@ code span.wa { color: #60a0b0; font-weight: bold; font-style: italic; } /* Warni <span id="cb4-2"><a href="#cb4-2" aria-hidden="true" tabindex="-1"></a><span class="cf">for</span> (dataname, dataset) <span class="kw">in</span> datasets</span> <span id="cb4-3"><a href="#cb4-3" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> λ <span class="kw">in</span> Λ, temp <span class="kw">in</span> temps</span> <span id="cb4-4"><a href="#cb4-4" aria-hidden="true" tabindex="-1"></a> _generators <span class="op">=</span> <span class="fu">deepcopy</span>(generators)</span> -<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a> _generators[<span class="op">:</span>cce] <span class="op">=</span> <span 
class="fu">CCEGenerator</span>(temp<span class="op">=</span>temp, λ<span class="op">=</span>[l2_λ,λ], opt<span class="op">=</span>opt)</span> -<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a> _generators[<span class="op">:</span>energy] <span class="op">=</span> CCE.<span class="fu">EnergyDrivenGenerator</span>(λ<span class="op">=</span>[l2_λ,λ], opt<span class="op">=</span>opt)</span> -<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a> _generators[<span class="op">:</span>target] <span class="op">=</span> CCE.<span class="fu">TargetDrivenGenerator</span>(λ<span class="op">=</span>[l2_λ,λ], opt<span class="op">=</span>opt)</span> +<span id="cb4-5"><a href="#cb4-5" aria-hidden="true" tabindex="-1"></a> _generators[<span class="op">:</span>cce] <span class="op">=</span> <span class="fu">ECCCEGenerator</span>(temp<span class="op">=</span>temp, λ<span class="op">=</span>[l2_λ,λ], opt<span class="op">=</span>opt)</span> +<span id="cb4-6"><a href="#cb4-6" aria-hidden="true" tabindex="-1"></a> _generators[<span class="op">:</span>energy] <span class="op">=</span> ECCCE.<span class="fu">EnergyDrivenGenerator</span>(λ<span class="op">=</span>[l2_λ,λ], opt<span class="op">=</span>opt)</span> +<span id="cb4-7"><a href="#cb4-7" aria-hidden="true" tabindex="-1"></a> _generators[<span class="op">:</span>target] <span class="op">=</span> ECCCE.<span class="fu">TargetDrivenGenerator</span>(λ<span class="op">=</span>[l2_λ,λ], opt<span class="op">=</span>opt)</span> <span id="cb4-8"><a href="#cb4-8" aria-hidden="true" tabindex="-1"></a> bmk <span class="op">=</span> <span class="fu">benchmark</span>(</span> <span id="cb4-9"><a href="#cb4-9" aria-hidden="true" tabindex="-1"></a> dataset; </span> <span id="cb4-10"><a href="#cb4-10" aria-hidden="true" tabindex="-1"></a> models<span class="op">=</span><span class="fu">deepcopy</span>(models), </span> diff --git a/docs/search.json b/docs/search.json index 
8413a2255dc87ccd21f3189b9376833df67cdd2f..c097c74a608f9bbb42131f8a9e2c671c696bc0cc 100644 --- a/docs/search.json +++ b/docs/search.json @@ -18,7 +18,7 @@ "href": "notebooks/proposal.html#conformal-counterfactual-explanations", "title": "1 High-Fidelity Counterfactual Explanations through Conformal Prediction", "section": "1.2 Conformal Counterfactual Explanations", - "text": "1.2 Conformal Counterfactual Explanations\nIn Section 1.1.2, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (CCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models.\n\n1.2.1 Minimizing Predictive Uncertainty\nSchut et al. (2021) demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. Similarly, Antorán et al. (2020) …\n\nProblem: restricted to Bayesian models.\nSolution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction.\n\n\n\n1.2.2 Background on Conformal Prediction\n\nDistribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.\nConformal prediction is instance-based. So is CE.\nTake any fitted model and turn it into a conformal model using calibration data.\nOur approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. 
Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway.\n\n\nDoes the coverage guarantee carry over to counterfactuals?\n\n\n\n1.2.3 Generating Conformal Counterfactuals\nWhile Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, Stutz et al. (2022) introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search …\n\nChallenge: still need to implement these loss functions." + "text": "1.2 Conformal Counterfactual Explanations\nIn Section 1.1.2, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models.\n\n1.2.1 Minimizing Predictive Uncertainty\nSchut et al. (2021) demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. Similarly, Antorán et al. (2020) …\n\nProblem: restricted to Bayesian models.\nSolution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction.\n\n\n\n1.2.2 Background on Conformal Prediction\n\nDistribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.\nConformal prediction is instance-based. 
So is CE.\nTake any fitted model and turn it into a conformal model using calibration data.\nOur approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway.\n\n\nDoes the coverage guarantee carry over to counterfactuals?\n\n\n\n1.2.3 Generating Conformal Counterfactuals\nWhile Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, Stutz et al. (2022) introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search …\n\nChallenge: still need to implement these loss functions." }, { "objectID": "notebooks/proposal.html#experiments", @@ -60,7 +60,7 @@ "href": "notebooks/intro.html#fidelity-and-plausibility", "title": "2 ConformalGenerator", "section": "2.4 Fidelity and Plausibility", - "text": "2.4 Fidelity and Plausibility\nThe main evaluation criteria we are interested in are fidelity and plausibility. Interestingly, we could also consider using these measures as penalties in the counterfactual search.\n\n2.4.1 Fidelity\nWe propose to define fidelity as follows:\n\nDefinition 2.1 (High-Fidelity Counterfactuals) Let \\(\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)\\) denote the class-conditional distribution of \\(X\\) defined by \\(\\theta\\). 
Then for \\(x^{\\prime}\\) to be considered a high-fidelity counterfactual, we need: \\(\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}\\) where \\(t\\) denotes the target outcome.\n\nWe can generate samples from \\(p_{\\theta}(X|y)\\) following Grathwohl et al. (2020). In Figure 2.2, I have applied the methodology to our synthetic data.\n\nM = CCE.ConformalModel(conf_model, mach.fitresult)\n\nniter = 100\nnsamples = 100\n\nplts = []\nfor (i,target) ∈ enumerate(counterfactual_data.y_levels)\n sampler = CCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100)\n Xgen = rand(sampler, nsamples)\n plt = Plots.plot(M, counterfactual_data; target=target, zoom=-3,cbar=false)\n Plots.scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=i,shape=:star,label=\"X|y=$target\")\n push!(plts, plt)\nend\nPlots.plot(plts..., layout=(1,length(plts)), size=(img_height*length(plts),img_height))\n\n\n\n\nFigure 2.2: Energy-based conditional samples.\n\n\n\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual \\(x^{\\prime}\\) from these generated samples, for example.\n\n\n2.4.2 Plausibility\nWe propose to define plausibility as follows:\n\nDefinition 2.2 (Plausible Counterfactuals) Formally, let \\(\\mathcal{X}|t\\) denote the conditional distribution of samples in the target class. As before, we have \\(x^{\\prime}\\sim\\mathcal{X}^{\\prime}\\), then for \\(x^{\\prime}\\) to be considered a plausible counterfactual, we need: \\(\\mathcal{X}|t \\approxeq \\mathcal{X}^{\\prime}\\).\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual \\(x^{\\prime}\\) from (potentially bootstrapped) training samples in the target class, for example." + "text": "2.4 Fidelity and Plausibility\nThe main evaluation criteria we are interested in are fidelity and plausibility. 
Interestingly, we could also consider using these measures as penalties in the counterfactual search.\n\n2.4.1 Fidelity\nWe propose to define fidelity as follows:\n\nDefinition 2.1 (High-Fidelity Counterfactuals) Let \\(\\mathcal{X}_{\\theta}|y = p_{\\theta}(X|y)\\) denote the class-conditional distribution of \\(X\\) defined by \\(\\theta\\). Then for \\(x^{\\prime}\\) to be considered a high-fidelity counterfactual, we need: \\(\\mathcal{X}_{\\theta}|t \\approxeq \\mathcal{X}^{\\prime}\\) where \\(t\\) denotes the target outcome.\n\nWe can generate samples from \\(p_{\\theta}(X|y)\\) following Grathwohl et al. (2020). In Figure 2.2, I have applied the methodology to our synthetic data.\n\nM = ECCCE.ConformalModel(conf_model, mach.fitresult)\n\nniter = 100\nnsamples = 100\n\nplts = []\nfor (i,target) ∈ enumerate(counterfactual_data.y_levels)\n sampler = ECCCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100)\n Xgen = rand(sampler, nsamples)\n plt = Plots.plot(M, counterfactual_data; target=target, zoom=-3,cbar=false)\n Plots.scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=i,shape=:star,label=\"X|y=$target\")\n push!(plts, plt)\nend\nPlots.plot(plts..., layout=(1,length(plts)), size=(img_height*length(plts),img_height))\n\n\n\n\nFigure 2.2: Energy-based conditional samples.\n\n\n\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual \\(x^{\\prime}\\) from these generated samples, for example.\n\n\n2.4.2 Plausibility\nWe propose to define plausibility as follows:\n\nDefinition 2.2 (Plausible Counterfactuals) Formally, let \\(\\mathcal{X}|t\\) denote the conditional distribution of samples in the target class. 
As before, we have \\(x^{\\prime}\\sim\\mathcal{X}^{\\prime}\\), then for \\(x^{\\prime}\\) to be considered a plausible counterfactual, we need: \\(\\mathcal{X}|t \\approxeq \\mathcal{X}^{\\prime}\\).\n\nAs an evaluation metric and penalty, we could use the average distance of the counterfactual \\(x^{\\prime}\\) from (potentially bootstrapped) training samples in the target class, for example." }, { "objectID": "notebooks/intro.html#counterfactual-explanations", @@ -81,14 +81,14 @@ "href": "notebooks/intro.html#benchmarks", "title": "2 ConformalGenerator", "section": "2.7 Benchmarks", - "text": "2.7 Benchmarks\n\n# Data:\ndatasets = Dict(\n :linearly_separable => load_linearly_separable(),\n :overlapping => load_overlapping(),\n :moons => load_moons(),\n :circles => load_circles(),\n :multi_class => load_multi_class(),\n)\n\n# Untrained Models:\nmodels = Dict(\n :cov75 => CCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.75)),\n :cov80 => CCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.80)),\n :cov90 => CCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.90)),\n :cov99 => CCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.99)),\n)\n\nThen we can simply loop over the datasets and eventually concatenate the results like so:\n\nusing CounterfactualExplanations.Evaluation: benchmark\nbmks = []\nmeasures = [\n CounterfactualExplanations.distance,\n CCE.distance_from_energy,\n CCE.distance_from_targets\n]\nfor (dataname, dataset) in datasets\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=10\n )\n push!(bmks, bmk)\nend\nbmk = reduce(vcat, bmks)\n\n\nf(ce) = CounterfactualExplanations.model_evaluation(ce.M, ce.data)\n@chain bmk() begin\n @group_by(model, generator, dataname, variable)\n @select(model, generator, dataname, ce, value)\n 
@mutate(performance = f(ce))\n @summarize(model=unique(model), generator=unique(generator), dataname=unique(dataname), performace=unique(performance), value=mean(value))\n @ungroup\n @filter(dataname == :multi_class)\n @filter(model == :cov99)\n @filter(variable == \"distance\")\nend\n\n\n\n\n\n\n\n(a) Circles.\n\n\n\n\n\n\n\n(b) Linearly Separable.\n\n\n\n\n\n\n\n(c) Moons.\n\n\n\n\n\n\n\n(d) Multi-class.\n\n\n\n\n\n\n\n(e) Overlapping.\n\n\n\nFigure 2.8: Benchmark results for the different generators.\n\n\n\n\n\n\nGrathwohl, Will, Kuan-Chieh Wang, Joern-Henrik Jacobsen, David Duvenaud, Mohammad Norouzi, and Kevin Swersky. 2020. “Your Classifier Is Secretly an Energy Based Model and You Should Treat It Like One.” In. https://openreview.net/forum?id=Hkxzx0NtDB.\n\n\nSchut, Lisa, Oscar Key, Rory Mc Grath, Luca Costabello, Bogdan Sacaleanu, Yarin Gal, et al. 2021. “Generating Interpretable Counterfactual Explanations By Implicit Minimisation of Epistemic and Aleatoric Uncertainties.” In International Conference on Artificial Intelligence and Statistics, 1756–64. PMLR.\n\n\nStutz, David, Krishnamurthy Dj Dvijotham, Ali Taylan Cemgil, and Arnaud Doucet. 2022. “Learning Optimal Conformal Classifiers.” In. https://openreview.net/forum?id=t8O-4LKFVx." 
+ "text": "2.7 Benchmarks\n\n# Data:\ndatasets = Dict(\n :linearly_separable => load_linearly_separable(),\n :overlapping => load_overlapping(),\n :moons => load_moons(),\n :circles => load_circles(),\n :multi_class => load_multi_class(),\n)\n\n# Untrained Models:\nmodels = Dict(\n :cov75 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.75)),\n :cov80 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.80)),\n :cov90 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.90)),\n :cov99 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.99)),\n)\n\nThen we can simply loop over the datasets and eventually concatenate the results like so:\n\nusing CounterfactualExplanations.Evaluation: benchmark\nbmks = []\nmeasures = [\n CounterfactualExplanations.distance,\n ECCCE.distance_from_energy,\n ECCCE.distance_from_targets\n]\nfor (dataname, dataset) in datasets\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=10\n )\n push!(bmks, bmk)\nend\nbmk = reduce(vcat, bmks)\n\n\nf(ce) = CounterfactualExplanations.model_evaluation(ce.M, ce.data)\n@chain bmk() begin\n @group_by(model, generator, dataname, variable)\n @select(model, generator, dataname, ce, value)\n @mutate(performance = f(ce))\n @summarize(model=unique(model), generator=unique(generator), dataname=unique(dataname), performace=unique(performance), value=mean(value))\n @ungroup\n @filter(dataname == :multi_class)\n @filter(model == :cov99)\n @filter(variable == \"distance\")\nend\n\n\n\n\n\n\n\n(a) Circles.\n\n\n\n\n\n\n\n(b) Linearly Separable.\n\n\n\n\n\n\n\n(c) Moons.\n\n\n\n\n\n\n\n(d) Multi-class.\n\n\n\n\n\n\n\n(e) Overlapping.\n\n\n\nFigure 2.8: Benchmark results for the different generators.\n\n\n\n\n\n\nGrathwohl, Will, Kuan-Chieh Wang, Joern-Henrik Jacobsen, David Duvenaud, 
Mohammad Norouzi, and Kevin Swersky. 2020. “Your Classifier Is Secretly an Energy Based Model and You Should Treat It Like One.” In. https://openreview.net/forum?id=Hkxzx0NtDB.\n\n\nSchut, Lisa, Oscar Key, Rory Mc Grath, Luca Costabello, Bogdan Sacaleanu, Yarin Gal, et al. 2021. “Generating Interpretable Counterfactual Explanations By Implicit Minimisation of Epistemic and Aleatoric Uncertainties.” In International Conference on Artificial Intelligence and Statistics, 1756–64. PMLR.\n\n\nStutz, David, Krishnamurthy Dj Dvijotham, Ali Taylan Cemgil, and Arnaud Doucet. 2022. “Learning Optimal Conformal Classifiers.” In. https://openreview.net/forum?id=t8O-4LKFVx." }, { "objectID": "notebooks/synthetic.html#benchmark", "href": "notebooks/synthetic.html#benchmark", "title": "3 Synthetic data", "section": "3.1 Benchmark", - "text": "3.1 Benchmark\n\n# Benchmark generators:\ngenerators = Dict(\n :wachter => GenericGenerator(opt=opt, λ=l2_λ),\n :revise => REVISEGenerator(opt=opt, λ=l2_λ),\n :greedy => GreedyGenerator(),\n)\n\n# Untrained Models:\nmodels = Dict(Symbol(\"cov$(Int(100*cov))\") => CCE.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs)\n\n# Measures:\nmeasures = [\n CounterfactualExplanations.distance,\n CCE.distance_from_energy,\n CCE.distance_from_targets,\n CounterfactualExplanations.validity,\n]\n\n\n3.1.1 Single CE\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n3.1.2 Full Benchmark\n\nbmks = []\nfor (dataname, dataset) in datasets\n for λ in Λ, temp in temps\n _generators = deepcopy(generators)\n _generators[:cce] = CCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt)\n _generators[:energy] = CCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n _generators[:target] = CCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=_generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n 
n_individuals=5,\n initialization=:identity,\n )\n bmk.evaluation.λ .= λ\n bmk.evaluation.temperature .= temp\n push!(bmks, bmk)\n end\nend\nbmk = reduce(vcat, bmks)\n\n\nCSV.write(joinpath(output_path, \"synthetic_benchmark.csv\"), bmk())" + "text": "3.1 Benchmark\n\n# Benchmark generators:\ngenerators = Dict(\n :wachter => GenericGenerator(opt=opt, λ=l2_λ),\n :revise => REVISEGenerator(opt=opt, λ=l2_λ),\n :greedy => GreedyGenerator(),\n)\n\n# Untrained Models:\nmodels = Dict(Symbol(\"cov$(Int(100*cov))\") => ECCCE.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs)\n\n# Measures:\nmeasures = [\n CounterfactualExplanations.distance,\n ECCCE.distance_from_energy,\n ECCCE.distance_from_targets,\n CounterfactualExplanations.validity,\n]\n\n\n3.1.1 Single CE\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n3.1.2 Full Benchmark\n\nbmks = []\nfor (dataname, dataset) in datasets\n for λ in Λ, temp in temps\n _generators = deepcopy(generators)\n _generators[:cce] = ECCCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt)\n _generators[:energy] = ECCCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n _generators[:target] = ECCCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt)\n bmk = benchmark(\n dataset; \n models=deepcopy(models), \n generators=_generators, \n measure=measures,\n suppress_training=false, dataname=dataname,\n n_individuals=5,\n initialization=:identity,\n )\n bmk.evaluation.λ .= λ\n bmk.evaluation.temperature .= temp\n push!(bmks, bmk)\n end\nend\nbmk = reduce(vcat, bmks)\n\n\nCSV.write(joinpath(output_path, \"synthetic_benchmark.csv\"), bmk())" }, { "objectID": "notebooks/references.html", diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml index a1c318aaa24efd6ebab22371e0af1b3394f2a008..d5d8d17a7a0a789c1182b3436ae25b9c0f895565 100644 --- a/notebooks/Manifest.toml +++ b/notebooks/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = 
"f47e03784c5ec1ca97608a676d4689997e56c59d" +project_hash = "bb24fa6d048fab99674a941d85c45a034b033aae" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] @@ -140,12 +140,6 @@ git-tree-sha1 = "19a35467a82e236ff51bc17a3a44b69ef35185a2" uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" version = "1.0.8+0" -[[deps.CCE]] -deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "JointEnergyModels", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "PkgTemplates", "Plots", "Random", "SliceMap", "Statistics", "StatsBase", "StatsPlots", "Term"] -path = ".." -uuid = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" -version = "0.1.0" - [[deps.CEnum]] git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -502,6 +496,12 @@ git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" version = "0.6.8" +[[deps.ECCCE]] +deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "JointEnergyModels", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "PkgTemplates", "Plots", "Random", "SliceMap", "Statistics", "StatsBase", "StatsPlots", "Term"] +path = ".." 
+uuid = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" +version = "0.1.0" + [[deps.EarCut_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "e3290f2d49e661fbd94046d7e3726ffcb2d41053" diff --git a/notebooks/Project.toml b/notebooks/Project.toml index c178b3ca0d66cd45e1f94ccabf64b0dc42fd5ef7..5e61dae429cc465ff36306b53f150ec683261131 100644 --- a/notebooks/Project.toml +++ b/notebooks/Project.toml @@ -1,6 +1,6 @@ [deps] AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67" -CCE = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" +ECCCE = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e" diff --git a/notebooks/intro.qmd b/notebooks/intro.qmd index 33eeecb3918e2222157d79934fe97271ca4e3806..1c2ad2720767c6c5a589aa18f191edf1d4756330 100644 --- a/notebooks/intro.qmd +++ b/notebooks/intro.qmd @@ -122,14 +122,14 @@ We can generate samples from $p_{\theta}(X|y)$ following @grathwohl2020your. 
In #| label: fig-energy #| output: true -M = CCE.ConformalModel(conf_model, mach.fitresult) +M = ECCCE.ConformalModel(conf_model, mach.fitresult) niter = 100 nsamples = 100 plts = [] for (i,target) ∈ enumerate(counterfactual_data.y_levels) - sampler = CCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100) + sampler = ECCCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100) Xgen = rand(sampler, nsamples) plt = Plots.plot(M, counterfactual_data; target=target, zoom=-3,cbar=false) Plots.scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=i,shape=:star,label="X|y=$target") @@ -294,11 +294,11 @@ Plots.plot(plts..., layout=(length(cvgs),length(cvgs)), size=(2img_height*length niter = 100 nsamples = 100 -M = CCE.ConformalModel(conf_model, mach.fitresult; likelihood=:classification_multi) +M = ECCCE.ConformalModel(conf_model, mach.fitresult; likelihood=:classification_multi) plts = [] for target ∈ counterfactual_data.y_levels - sampler = CCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100) + sampler = ECCCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100) Xgen = rand(sampler, nsamples) plt = Plots.plot(M, counterfactual_data; target=target, zoom=-0.5,cbar=false) Plots.scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target") @@ -354,10 +354,10 @@ datasets = Dict( # Untrained Models: models = Dict( - :cov75 => CCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.75)), - :cov80 => CCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.80)), - :cov90 => CCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.90)), - :cov99 => CCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.99)), + :cov75 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.75)), + :cov80 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.80)), + 
:cov90 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.90)), + :cov99 => ECCCE.ConformalModel(conformal_model(clf; method=:simple_inductive, coverage=0.99)), ) ``` @@ -369,8 +369,8 @@ using CounterfactualExplanations.Evaluation: benchmark bmks = [] measures = [ CounterfactualExplanations.distance, - CCE.distance_from_energy, - CCE.distance_from_targets + ECCCE.distance_from_energy, + ECCCE.distance_from_targets ] for (dataname, dataset) in datasets bmk = benchmark( diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 31fcec4de3643555097f750b4b44f9dee4e51231..d26f3ca62ce20d7461625df854a428b0e12761c8 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -5,11 +5,20 @@ eval(setup_notebooks) # MNIST +```{julia} +function pre_process(x; noise::Float32=0.03f0) + ϵ = Float32.(randn(size(x)) * noise) + x = @.(2 * x - 1) .+ ϵ + return x +end +``` + ```{julia} # Data: n_obs = 10000 counterfactual_data = load_mnist(n_obs) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) +X = pre_process.(X) X = table(permutedims(X)) labels = counterfactual_data.output_encoder.labels input_dim, n_obs = size(counterfactual_data.X) @@ -23,16 +32,18 @@ First, let's create a couple of image classifier architectures: # Model parameters: epochs = 100 batch_size = minimum([Int(round(n_obs/10)), 128]) -n_hidden = 200 +n_hidden = 32 activation = Flux.relu -builder = MLJFlux.@builder Flux.Chain( - Dense(n_in, n_hidden), - BatchNorm(n_hidden, activation), - Dense(n_hidden, n_hidden), - BatchNorm(n_hidden, activation), - Dense(n_hidden, n_out), -) -# builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.2, σ=activation) +# builder = MLJFlux.@builder Flux.Chain( +# Dense(n_in, n_hidden, activation), +# Dense(n_hidden, n_hidden, activation), +# Dense(n_hidden, n_hidden, activation), +# # BatchNorm(n_hidden, activation), +# # Dense(n_hidden, n_hidden), +# # BatchNorm(n_hidden, activation), +# Dense(n_hidden, n_out), +# ) 
+builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.2, σ=activation) # builder = MLJFlux.MLP( # hidden=( # n_hidden, @@ -51,7 +62,7 @@ mlp = NeuralNetworkClassifier( ) # Joint Energy Model: -𝒟x = Uniform(0,1) +𝒟x = Uniform(-1,1) 𝒟y = Categorical(ones(output_dim) ./ output_dim) sampler = ConditionalSampler( 𝒟x, 𝒟y, @@ -78,11 +89,11 @@ mlp_ens = EnsembleModel(model=mlp, n=5) ``` ```{julia} -cov = .90 +cov = .95 conf_model = conformal_model(jem; method=:adaptive_inductive, coverage=cov) mach = machine(conf_model, X, labels) fit!(mach) -M = CCE.ConformalModel(mach.model, mach.fitresult) +M = ECCCE.ConformalModel(mach.model, mach.fitresult) ``` ```{julia} @@ -140,12 +151,12 @@ ce_jsma = generate_counterfactual( initialization=:identity, ) -# CCE: +# ECCCE: λ=[0.0,1.0] temp=0.01 -# Generate counterfactual using CCE generator: -generator = CCEGenerator( +# Generate counterfactual using ECCCE generator: +generator = ECCCEGenerator( λ=λ, temp=temp, opt=Flux.Optimise.Adam(), @@ -157,8 +168,8 @@ ce_conformal = generate_counterfactual( converge_when=:generator_conditions, ) -# Generate counterfactual using CCE generator: -generator = CCEGenerator( +# Generate counterfactual using ECCCE generator: +generator = ECCCEGenerator( λ=λ, temp=temp, opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0), @@ -180,7 +191,7 @@ p1 = Plots.plot( plts = [p1] ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma] -_names = ["Wachter", "CCE", "JSMA", "CCE-JSMA"] +_names = ["Wachter", "ECCCE", "JSMA", "ECCCE-JSMA"] for x in zip(ces, _names) ce, _name = (x[1],x[2]) x = CounterfactualExplanations.counterfactual(ce) @@ -215,8 +226,8 @@ generators = Dict( # Measures: measures = [ CounterfactualExplanations.distance, - CCE.distance_from_energy, - CCE.distance_from_targets, + ECCCE.distance_from_energy, + ECCCE.distance_from_targets, CounterfactualExplanations.validity, ] ``` \ No newline at end of file diff --git a/notebooks/proposal.qmd b/notebooks/proposal.qmd index 
5ee92707e18cec50de9b3918c6547ebdc5c4084e..bdbd45b34b2cef38cf7889a408930e95cab37ca8 100644 --- a/notebooks/proposal.qmd +++ b/notebooks/proposal.qmd @@ -188,7 +188,7 @@ If the computed value is different from zero, we can reject the null-hypothesis ## Conformal Counterfactual Explanations -In @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (CCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models. +In @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (ECCCE), that is Counterfactual Explanations that minimize the predictive uncertainty of conformal models. 
### Minimizing Predictive Uncertainty diff --git a/notebooks/setup.jl b/notebooks/setup.jl index 6f9ae26e480d8f327c3677de96cb0c591769f9ff..5997cae6d12ea9c680bea6c7e9475c6fda393a2a 100644 --- a/notebooks/setup.jl +++ b/notebooks/setup.jl @@ -6,8 +6,8 @@ setup_notebooks = quote using AlgebraOfGraphics using AlgebraOfGraphics: Violin, BoxPlot, BarPlot using CairoMakie - using CCE - using CCE: set_size_penalty, distance_from_energy, distance_from_targets + using ECCCE + using ECCCE: set_size_penalty, distance_from_energy, distance_from_targets using Chain: @chain using ConformalPrediction using CounterfactualExplanations diff --git a/notebooks/synthetic.qmd b/notebooks/synthetic.qmd index b97802c6002d4b275e50106213fc1e92d7031d9e..0ac30edb952175fd984e4cbdf1c68fcdcc2168d5 100644 --- a/notebooks/synthetic.qmd +++ b/notebooks/synthetic.qmd @@ -60,16 +60,16 @@ for (dataname, data) in datasets conf_model = conformal_model(clf; method=:simple_inductive, coverage=cov) mach = machine(conf_model, X, y) fit!(mach) - M = CCE.ConformalModel(mach.model, mach.fitresult) + M = ECCCE.ConformalModel(mach.model, mach.fitresult) - # Set up CCE: + # Set up ECCCE: factual_label = predict_label(M, data, x)[1] target_label = data.y_levels[data.y_levels .!= factual_label][1] for λ in Λ, temp in temps - # CCE for given classifier, coverage, temperature and λ: - generator = CCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt) + # ECCCE for given classifier, coverage, temperature and λ: + generator = ECCCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt) @assert predict_label(M, data, x) != target_label ce = try generate_counterfactual( @@ -158,13 +158,13 @@ generators = Dict( ) # Untrained Models: -models = Dict(Symbol("cov$(Int(100*cov))") => CCE.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs) +models = Dict(Symbol("cov$(Int(100*cov))") => ECCCE.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs) # Measures: measures = [ 
CounterfactualExplanations.distance, - CCE.distance_from_energy, - CCE.distance_from_targets, + ECCCE.distance_from_energy, + ECCCE.distance_from_targets, CounterfactualExplanations.validity, ] ``` @@ -187,7 +187,7 @@ for (dataname, data) in datasets # Model training: M = train(M, data) - # Set up CCE: + # Set up ECCCE: factual_label = predict_label(M, data, x)[1] target_label = data.y_levels[data.y_levels .!= factual_label][1] @@ -195,13 +195,13 @@ for (dataname, data) in datasets # Generators: _generators = deepcopy(generators) - _generators[:cce] = CCEGenerator(temp=_temp, λ=[l2_λ,λ], opt=opt) - _generators[:energy] = CCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt) - _generators[:target] = CCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt) + _generators[:cce] = ECCCEGenerator(temp=_temp, λ=[l2_λ,λ], opt=opt) + _generators[:energy] = ECCCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt) + _generators[:target] = ECCCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt) for (gen_name, gen) in _generators - # CCE for given models, λ and generator: + # ECCCE for given models, λ and generator: @assert predict_label(M, data, x) != target_label ce = try generate_counterfactual( @@ -300,9 +300,9 @@ bmks = [] for (dataname, dataset) in datasets for λ in Λ, temp in temps _generators = deepcopy(generators) - _generators[:cce] = CCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt) - _generators[:energy] = CCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt) - _generators[:target] = CCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt) + _generators[:cce] = ECCCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt) + _generators[:energy] = ECCCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt) + _generators[:target] = ECCCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt) bmk = benchmark( dataset; models=deepcopy(models), diff --git a/paper/paper.tex b/paper/paper.tex index 50fc37d12de8157060f6cff4e7e34995344b5323..13799de77ebc15c684582cf5e28ce0d018267804 100644 --- a/paper/paper.tex +++ b/paper/paper.tex @@ -266,7 +266,7 @@ The fact that 
conformal classifiers produce set-valued predictions introduces a where $\kappa \in \{0,1\}$ is a hyper-parameter and $C_{\theta,\mathbf{y}}(\mathbf{x}_i;\alpha)$ can be interpreted as the probability of label $\mathbf{y}$ being included in the prediction set. Formally, it is defined as $C_{\theta,\mathbf{y}}(\mathbf{x}_i;\alpha):=\sigma\left((s(\mathbf{x}_i,\mathbf{y})-\alpha) T^{-1}\right)$ for $\mathbf{y}\in\mathcal{Y}$ where $\sigma$ is the sigmoid function and $T$ is a hyper-parameter used for temperature scaling \citep{stutz2022learning}. -Penalizing the set size in this way is in principal enough to train efficient conformal classifiers \citep{stutz2022learning}. As we explained above, the set size is also closely linked to predictive uncertainty at the local level. This makes the smooth penalty defined in Equation~\ref{eq:setsize} useful in the context of meeting our objective of generating plausible counterfactuals. In particular, we adapt Equation~\ref{eq:general} to define the baseline objective for Conformal Counterfactual Explanations (CCE): +Penalizing the set size in this way is in principal enough to train efficient conformal classifiers \citep{stutz2022learning}. As we explained above, the set size is also closely linked to predictive uncertainty at the local level. This makes the smooth penalty defined in Equation~\ref{eq:setsize} useful in the context of meeting our objective of generating plausible counterfactuals. In particular, we adapt Equation~\ref{eq:general} to define the baseline objective for Conformal Counterfactual Explanations (ECCCE): \begin{equation}\label{eq:cce} \begin{aligned} @@ -276,7 +276,7 @@ Penalizing the set size in this way is in principal enough to train efficient co Since we can still retrieve unperturbed softmax outputs from our conformal classifier $M_{\theta}$, we are free to work with any loss function of our choice. For example, we could use standard cross-entropy for $\text{yloss}$. 
-In order to generate prediction sets $C_{\theta}(f(\mathbf{Z}^\prime);\alpha)$ for any Black Box Model we merely need to perform a single calibration pass through a holdout set $\mathcal{D}_{\text{cal}}$. Arguably, data is typically abundant and in most applications practitioners tend to hold out a test data set anyway. Our proposed approach for CCE therefore removes the restriction on the family of predictive models, at the small cost of reserving a subset of the available data for calibration. +In order to generate prediction sets $C_{\theta}(f(\mathbf{Z}^\prime);\alpha)$ for any Black Box Model we merely need to perform a single calibration pass through a holdout set $\mathcal{D}_{\text{cal}}$. Arguably, data is typically abundant and in most applications practitioners tend to hold out a test data set anyway. Our proposed approach for ECCCE therefore removes the restriction on the family of predictive models, at the small cost of reserving a subset of the available data for calibration. \section{Experiments} diff --git a/src/CCE.jl b/src/ECCCE.jl similarity index 67% rename from src/CCE.jl rename to src/ECCCE.jl index 809408f16ccc7b2e5619e2fcb5e163eef83c5501..8a031fffa21f2d3f65f938b79725f348dcfc8703 100644 --- a/src/CCE.jl +++ b/src/ECCCE.jl @@ -1,4 +1,4 @@ -module CCE +module ECCCE using CounterfactualExplanations import MLJModelInterface as MMI @@ -9,6 +9,6 @@ include("losses.jl") include("generator.jl") include("sampling.jl") -export CCEGenerator, EnergySampler, set_size_penalty, distance_from_energy +export ECCCEGenerator, EnergySampler, set_size_penalty, distance_from_energy end \ No newline at end of file diff --git a/src/generator.jl b/src/generator.jl index d379f7512483bba945de146d2404cf28e9b54be7..ac598d483e8a6f246210b4fa47daf436ba508b14 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -1,25 +1,43 @@ using CounterfactualExplanations.Objectives -"Constructor for `CCEGenerator`." 
-function CCEGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], κ::Real=1.0, temp::Real=0.05, kwargs...) +"Constructor for `ECCCEGenerator`." +function ECCCEGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], κ::Real=1.0, temp::Real=0.05, kwargs...) function _set_size_penalty(ce::AbstractCounterfactualExplanation) - return CCE.set_size_penalty(ce; κ=κ, temp=temp) + return ECCCE.set_size_penalty(ce; κ=κ, temp=temp) end _penalties = [Objectives.distance_l2, _set_size_penalty] λ = λ isa AbstractFloat ? [0.0, λ] : λ return Generator(; penalty=_penalties, λ=λ, kwargs...) end +"Constructor for `ECECCCEGenerator`: Energy Constrained Conformal Counterfactual Explanation Generator." +function ECECCCEGenerator(; + λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0, 1.0], + κ::Real=1.0, + temp::Real=0.5, + η::Union{Nothing,Real}=nothing, + n::Union{Nothing,Int}=nothing, + opt::Flux.Optimise.AbstractOptimiser=CounterfactualExplanations.Generators.JSMADescent(η=η,n=n), + kwargs... +) + function _set_size_penalty(ce::AbstractCounterfactualExplanation) + return ECCCE.set_size_penalty(ce; κ=κ, temp=temp) + end + _penalties = [Objectives.distance_l2, _set_size_penalty, ECCCE.distance_from_energy] + λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ + return Generator(; penalty=_penalties, λ=λ, opt=opt, kwargs...) +end + "Constructor for `EnergyDrivenGenerator`." function EnergyDrivenGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], kwargs...) - _penalties = [Objectives.distance_l2, CCE.distance_from_energy] + _penalties = [Objectives.distance_l2, ECCCE.distance_from_energy] λ = λ isa AbstractFloat ? [0.0, λ] : λ return Generator(; penalty=_penalties, λ=λ, kwargs...) end "Constructor for `TargetDrivenGenerator`." function TargetDrivenGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], kwargs...) 
- _penalties = [Objectives.distance_l2, CCE.distance_from_targets] + _penalties = [Objectives.distance_l2, ECCCE.distance_from_targets] λ = λ isa AbstractFloat ? [0.0, λ] : λ return Generator(; penalty=_penalties, λ=λ, kwargs...) end \ No newline at end of file diff --git a/src/penalties.jl b/src/penalties.jl index 30ae653c91f21d06eb794ffa3369e9c5cdb6f9a7..37474c9e43555dc497bdb9f766ec2915d45d5e63 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -42,7 +42,7 @@ function distance_from_energy( ignore_derivatives() do _dict = ce.params if !(:energy_sampler ∈ collect(keys(_dict))) - _dict[:energy_sampler] = CCE.EnergySampler(ce; kwargs...) + _dict[:energy_sampler] = ECCCE.EnergySampler(ce; kwargs...) end sampler = _dict[:energy_sampler] push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer)) diff --git a/test/runtests.jl b/test/runtests.jl index d86e70bf27b1d6bff8d3469daf0db85ca078d5c7..569772e439080dfee852354167d0926c7c60170f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ -using CCE +using ECCCE using Test -@testset "CCE.jl" begin +@testset "ECCCE.jl" begin # Write your tests here. end diff --git a/www/cce_mnist.png b/www/cce_mnist.png index 07f2ae43c15f6032db97bbfdfdf748de88051151..3db6423c613ebc10174d84b0ae8f6f3cf8601ff0 100644 Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ
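
Note on the `paper/paper.tex` hunk: it defines the soft inclusion score $C_{\theta,\mathbf{y}}(\mathbf{x}_i;\alpha)=\sigma\left((s(\mathbf{x}_i,\mathbf{y})-\alpha) T^{-1}\right)$, and the generator constructors in `src/generator.jl` pass `κ` and `temp` to `set_size_penalty`. The following is a minimal sketch of how such a smooth set size penalty could be computed. It is not the package's implementation: the function names `soft_inclusion` and `smooth_set_size` are hypothetical, the `max(0, sum - κ)` form is assumed from \citep{stutz2022learning} because Equation `eq:setsize` is not part of this diff, and the scores are assumed to be plain per-label numbers.

```julia
# Hypothetical helper names; this is a sketch, not ECCCE.set_size_penalty.

# Logistic sigmoid, written out to keep the example dependency-free.
sigmoid(z) = 1 / (1 + exp(-z))

# Soft probability that each label is included in the prediction set:
# σ((s(x, y) - α) / T), applied element-wise over the per-label scores.
soft_inclusion(scores, α; temp=0.05) = sigmoid.((scores .- α) ./ temp)

# Assumed form of the smooth set size penalty: the soft set size minus the
# target size κ, clipped at zero so that sets of size ≤ κ incur no penalty.
function smooth_set_size(scores, α; κ=1.0, temp=0.05)
    return max(0.0, sum(soft_inclusion(scores, α; temp=temp)) - κ)
end

# One confident label: the soft set size stays below κ = 1, so no penalty.
confident = smooth_set_size([0.9, 0.05, 0.05], 0.5)

# Two labels above the threshold α: the soft set size exceeds κ = 1.
ambiguous = smooth_set_size([0.6, 0.55, 0.05], 0.5)
```

With the default `temp=0.05` used by `ECCCEGenerator`, the sigmoid is steep, so the penalty behaves almost like a count of labels above `α`; a larger `temp` (such as the `0.5` default of the second constructor) gives a smoother penalty and less sharply varying gradients.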