diff --git a/artifacts/mnist_classifier.jls b/artifacts/mnist_classifier.jls new file mode 100644 index 0000000000000000000000000000000000000000..1a66ab701028818f831939d432d2e82f2efaecae Binary files /dev/null and b/artifacts/mnist_classifier.jls differ diff --git a/artifacts/mnist_vae.jls b/artifacts/mnist_vae.jls new file mode 100644 index 0000000000000000000000000000000000000000..c1398fc04eba1b362806f624d0a8b4064277d718 Binary files /dev/null and b/artifacts/mnist_vae.jls differ diff --git a/artifacts/mnist_vae_weak.jls b/artifacts/mnist_vae_weak.jls new file mode 100644 index 0000000000000000000000000000000000000000..18d2a0577e38d640e13ce207ad9497d034cc3a91 Binary files /dev/null and b/artifacts/mnist_vae_weak.jls differ diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml index 29fae8d01cc8e827d0242b3c218480ddac655b2c..49e756b6a313eb0a19fdc90dd8e6cea4dcb1e0f8 100644 --- a/notebooks/Manifest.toml +++ b/notebooks/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "c8134a823a7f1e6a7b6ed8be16a2c99858bd6d32" +project_hash = "38a34f0cc2271777d881bdfe6cc3a2c09bd2663c" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] @@ -10,6 +10,11 @@ git-tree-sha1 = "16b6dbc4cf7caee4e1e75c49485ec67b667098a0" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" version = "1.3.1" +[[deps.AbstractTrees]] +git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.4.4" + [[deps.Accessors]] deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "StaticArrays", "Test"] git-tree-sha1 = "beabc31fa319f9de4d16372bff31b4801e43d32c" @@ -118,7 +123,7 @@ uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" version = "1.0.8+0" [[deps.CCE]] -deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "JointEnergyModels", "LinearAlgebra", 
"MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "Plots", "Random", "SliceMap", "Statistics", "StatsBase", "StatsPlots"] +deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "JointEnergyModels", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "Plots", "Random", "SliceMap", "Statistics", "StatsBase", "StatsPlots", "Term"] path = ".." uuid = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" version = "0.1.0" @@ -230,6 +235,12 @@ git-tree-sha1 = "7ebbd653f74504447f1c33b91cd706a69a1b189f" uuid = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" version = "0.14.4" +[[deps.CodeTracking]] +deps = ["InteractiveUtils", "UUIDs"] +git-tree-sha1 = "d57c99cc7e637165c81b30eb268eabe156a45c49" +uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" +version = "1.2.2" + [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] git-tree-sha1 = "9c209fb7536406834aa938fb149964b985de6c83" @@ -690,6 +701,12 @@ git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" version = "2.8.1+1" +[[deps.Highlights]] +deps = ["DocStringExtensions", "InteractiveUtils", "REPL"] +git-tree-sha1 = "0341077e8a6b9fc1c2ea5edc1e93a956d2aec0c7" +uuid = "eafb193a-b7ab-5a9e-9068-77385905fa72" +version = "0.5.2" + [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions", "Test"] git-tree-sha1 = "709d864e3ed6e3545230601f94e11ebc65994641" @@ -1251,6 +1268,11 @@ git-tree-sha1 = "91a48569383df24f0fd2baf789df2aade3d0ad80" uuid = "6f286f6a-111f-5878-ab1e-185364afe411" version = "0.10.1" +[[deps.MyterialColors]] +git-tree-sha1 = "01d8466fb449436348999d7c6ad740f8f853a579" +uuid = "1c23619d-4212-4747-83aa-717207fae70f" +version = "0.3.0" + [[deps.NLSolversBase]] deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] git-tree-sha1 = "a0b464d183da839699f4c79e7606d9d186ec172c" @@ -1874,6 +1896,12 @@ 
git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" version = "0.1.1" +[[deps.Term]] +deps = ["AbstractTrees", "CodeTracking", "Dates", "Highlights", "InteractiveUtils", "Logging", "Markdown", "MyterialColors", "OrderedCollections", "Parameters", "ProgressLogging", "REPL", "SnoopPrecompile", "Tables", "UUIDs", "Unicode", "UnicodeFun"] +git-tree-sha1 = "373d65207cb8de6d2e7bd32b89476e760c6edc4d" +uuid = "22787eb5-b846-44ae-b979-8e399b8463ab" +version = "2.0.2" + [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/notebooks/Project.toml b/notebooks/Project.toml index 7de2a71557207ad46fb5353699faa82897259c1f..77c871083b5a16a47f235f428ed5517aac857eb2 100644 --- a/notebooks/Project.toml +++ b/notebooks/Project.toml @@ -9,5 +9,7 @@ JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" diff --git a/notebooks/proposal.qmd b/notebooks/proposal.qmd index 4c4aa5c74ac5198834b6c296e5502792031ca1d2..5ee92707e18cec50de9b3918c6547ebdc5c4084e 100644 --- a/notebooks/proposal.qmd +++ b/notebooks/proposal.qmd @@ -24,41 +24,56 @@ counterfactual_data = load_mnist() X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) input_dim, n_obs = size(counterfactual_data.X) M = load_mnist_mlp() + # Target: factual_label = 8 x = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) target = 3 factual = predict_label(M, counterfactual_data, x)[1] +γ = 0.9 +T = 50 +``` + +```{julia} # Search: -n_ce = 3 generator = GenericGenerator() -ces = 
generate_counterfactual(x, target, counterfactual_data, M, generator; num_counterfactuals=n_ce) +ce_wachter = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, +) +generator = GreedyGenerator(η=5.0) +ce_jsma = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, +) ``` ```{julia} -image_size = 200 p1 = plot( convert2image(MNIST, reshape(x,28,28)), axis=nothing, - size=(image_size, image_size), + size=(img_height, img_height), title="Factual" ) plts = [p1] -counterfactuals = CounterfactualExplanations.counterfactual(ces) -phat = target_probs(ces) +ces = zip([ce_wachter,ce_jsma]) +counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces)) +phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces)) for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3)) ce, _phat = (x[1],x[2]) _title = "p(y=$(target)|x′)=$(round(_phat[1]; digits=3))" plt = plot( convert2image(MNIST, reshape(ce,28,28)), axis=nothing, - size=(image_size, image_size), + size=(img_height, img_height), title=_title ) plts = [plts..., plt] end -plt = plot(plts...; size=(image_size * (n_ce + 1),image_size), layout=(1,(n_ce + 1))) +plt = plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) savefig(plt, joinpath(www_path, "you_may_not_like_it.png")) ``` @@ -76,6 +91,57 @@ Formally, if $x \sim \mathcal{X}$ and for the corresponding counterfactual we ha In the context of Algorithmic Recourse, it makes sense to strive for plausible counterfactuals, since anything else would essentially require individuals to move to out-of-distribution states. 
But it is worth noting that our ambition to meet this goal, may have implications on our ability to faithfully explain the behaviour of the underlying black-box model (arguably our principal goal). By essentially decoupling the task of learning plausible representations of the data from the model itself, we open ourselves up to vulnerabilities. Using a separate generative model to learn $\mathcal{X}$, for example, has very serious implications for the generated counterfactuals. @fig-latent compares the results of applying REVISE [@joshi2019realistic] to MNIST data using two different Variational Auto-Encoders: while the counterfactual generated using an expressive (strong) VAE is compelling, the result relying on a less expressive (weak) VAE is not even valid. In this latter case, the decoder step of the VAE fails to yield values in $\mathcal{X}$ and hence the counterfactual search in the learned latent space is doomed. +```{julia} +using CounterfactualExplanations.Models: load_mnist_vae +vae = load_mnist_vae() +vae_weak = load_mnist_vae(;strong=false) +Serialization.serialize(joinpath(output_path,"mnist_classifier.jls"), M) +Serialization.serialize(joinpath(output_path,"mnist_vae.jls"), vae) +Serialization.serialize(joinpath(output_path,"mnist_vae_weak.jls"), vae_weak) +``` + +```{julia} +# Define generator: +generator = REVISEGenerator( + opt = Descent(0.1), + λ=0.01 +) +# Generate recourse: +counterfactual_data.generative_model = vae # assign generative model +ce = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, +) +counterfactual_data = deepcopy(counterfactual_data) +counterfactual_data.generative_model = vae_weak +ce_weak = generate_counterfactual( + x, target, counterfactual_data, M, generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, +) +``` + +```{julia} +ces = zip([ce,ce_weak]) +counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> 
CounterfactualExplanations.counterfactual(ce[1]), ces)) +phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces)) +plts = [p1] +for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3)) + ce, _phat = (x[1],x[2]) + _title = "p(y=$(target)|x′)=$(round(_phat[1]; digits=3))" + plt = plot( + convert2image(MNIST, reshape(ce,28,28)), + axis=nothing, + size=(img_height, img_height), + title=_title + ) + plts = [plts..., plt] +end +plt = plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) +savefig(plt, joinpath(www_path, "surrogate_gone_wrong.png")) +``` + {#fig-latent} > Here it would be nice to have another example where we poison the data going into the generative model to hide biases present in the data (e.g. Boston housing). diff --git a/notebooks/setup.jl b/notebooks/setup.jl index 00f4e67b8c6a9878ab52e0eedf07fc273b067916..080ced2d94737fec2127d0694fc846aac99d4d54 100644 --- a/notebooks/setup.jl +++ b/notebooks/setup.jl @@ -19,13 +19,16 @@ setup_notebooks = quote using MLDatasets: convert2image using MLJBase using MLJFlux + using MLUtils using Plots using Random + using Serialization # Setup: theme(:wong) Random.seed!(2023) - www_path = "notebooks/www" + www_path = "www" + output_path = "artifacts" img_height = 300 end; \ No newline at end of file diff --git a/paper/paper.pdf b/paper/paper.pdf index 4a9de934a99d1fa56f9c1d8d5def0e76693c7479..4f9a26c455d005249c1d8091009f6d3b5a72fd1a 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/paper/paper.tex b/paper/paper.tex index 7c6c3ff02905619f0ebf654db7075e3c0930db32..1ee860d8ef565ca51ab2cf54ee75df981406a38b 100644 --- a/paper/paper.tex +++ b/paper/paper.tex @@ -127,13 +127,13 @@ To properly serve both AI practitioners and individuals affected by AI decision- \centering \begin{minipage}[t]{0.45\textwidth} \centering - \includegraphics[width=\textwidth]{www/you_may_not_like_it.png} - \caption{You may not like it, but this is what 
stripped-down counterfactuals look like. Counterfactuals for turning an 8 (eight) into a 3 (three) using different methodologies. Left: Using the generic approach proposed in \citet{wachter2017counterfactual}. Center: Using REVISE \citep{joshi2019realistic} with a well-specified surrogate. Right: Using REVISE \citep{joshi2019realistic} with a poorly specified surrogate.}\label{fig:adv} + \includegraphics[width=\textwidth]{../www/you_may_not_like_it.png} + \caption{You may not like it, but this is what stripped-down counterfactuals look like. Counterfactuals for turning an 8 (eight) into a 3 (three): original image (left); counterfactual produced using \citet{wachter2017counterfactual} (center); and a counterfactual produced using the JSMA-based approach introduced by \citet{schut2021generating} (right).}\label{fig:adv} \end{minipage}\hfill \begin{minipage}[t]{0.45\textwidth} \centering - \includegraphics[width=\textwidth]{www/mnist_9to4_latent.png} - \caption{Using surrogates introduces a dependency. Counterfactuals for turning a 9 (nine) into a 4 (four) using REVISE: original image (left); counterfactual using a well-specified surrogate (center); and counterfactual using a poorly specified surrogate (right).}\label{fig:vae} + \includegraphics[width=\textwidth]{../www/surrogate_gone_wrong.png} + \caption{Using surrogates can improve plausibility, but also increases vulnerability. 
Counterfactuals for turning an 8 (eight) into a 3 (three): original image (left); counterfactual produced using REVISE \citep{joshi2019realistic} with a well-specified surrogate (center); and a counterfactual produced using REVISE \citep{joshi2019realistic} with a poorly specified surrogate (right).}\label{fig:vae} \end{minipage} \end{figure} diff --git a/paper/www/mnist_9to4_latent.png b/paper/www/mnist_9to4_latent.png deleted file mode 100644 index 8f6b3a53c7295ef3678d841e79dc5e0bd38f6e3c..0000000000000000000000000000000000000000 Binary files a/paper/www/mnist_9to4_latent.png and /dev/null differ diff --git a/paper/www/you_may_not_like_it.png b/paper/www/you_may_not_like_it.png deleted file mode 100644 index 78cc0733619723f5a0b54693a978ae2ce5123cbf..0000000000000000000000000000000000000000 Binary files a/paper/www/you_may_not_like_it.png and /dev/null differ diff --git a/www/mnist_factual.png b/www/mnist_factual.png new file mode 100644 index 0000000000000000000000000000000000000000..2efbd44a9ec1e524ca9d8bcfdaa3c7a05c8ac782 Binary files /dev/null and b/www/mnist_factual.png differ diff --git a/www/mnist_vae_strong.png b/www/mnist_vae_strong.png new file mode 100644 index 0000000000000000000000000000000000000000..61cdc0381660f21f38de92b0e485d31e2a81c80d Binary files /dev/null and b/www/mnist_vae_strong.png differ diff --git a/www/mnist_vae_weak.png b/www/mnist_vae_weak.png new file mode 100644 index 0000000000000000000000000000000000000000..ac6e386efba06a5d95150188f6d3f3a51e40bf62 Binary files /dev/null and b/www/mnist_vae_weak.png differ diff --git a/www/surrogate_gone_wrong.png b/www/surrogate_gone_wrong.png new file mode 100644 index 0000000000000000000000000000000000000000..8c10938c334102a5516e9293a7ea85ab798f3006 Binary files /dev/null and b/www/surrogate_gone_wrong.png differ diff --git a/www/you_may_not_like_it.png b/www/you_may_not_like_it.png new file mode 100644 index 0000000000000000000000000000000000000000..5373f594f55dbf787d66dacba7ee86ebe0e64296 
Binary files /dev/null and b/www/you_may_not_like_it.png differ