From 136467dcb7622518cf1d70dacc384f1cde65e9a6 Mon Sep 17 00:00:00 2001
From: pat-alt <altmeyerpat@gmail.com>
Date: Tue, 28 Feb 2023 14:37:00 +0100
Subject: [PATCH] some work on configurable loss function

---
 Manifest.toml             | 205 +++++++++++++++++++++++---------------
 notebooks/conformal.qmd   |  45 +++++++++
 src/ConformalGenerator.jl |  44 ++++++--
 3 files changed, 204 insertions(+), 90 deletions(-)

diff --git a/Manifest.toml b/Manifest.toml
index ec1a4d53..d739fbc5 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -12,15 +12,15 @@ version = "1.2.1"
 
 [[deps.Accessors]]
 deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "StaticArrays", "Test"]
-git-tree-sha1 = "b9661b900b50ba475145b311a9a0ef9d2a9c85ea"
+git-tree-sha1 = "beabc31fa319f9de4d16372bff31b4801e43d32c"
 uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
-version = "0.1.26"
+version = "0.1.28"
 
 [[deps.Adapt]]
-deps = ["LinearAlgebra"]
-git-tree-sha1 = "0310e08cb19f5da31d08341c6120c047598f5b9c"
+deps = ["LinearAlgebra", "Requires"]
+git-tree-sha1 = "cc37d689f599e8df4f464b2fa3870ff7db7492ef"
 uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-version = "3.5.0"
+version = "3.6.1"
 
 [[deps.ArgCheck]]
 git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4"
@@ -49,20 +49,25 @@ git-tree-sha1 = "5ba6c757e8feccf03a1554dfaf3e26b3cfc7fd5e"
 uuid = "68821587-b530-5797-8361-c406ea357684"
 version = "3.5.1+1"
 
-[[deps.ArrayInterfaceCore]]
-deps = ["LinearAlgebra", "SnoopPrecompile", "SparseArrays", "SuiteSparse"]
-git-tree-sha1 = "e5f08b5689b1aad068e01751889f2f615c7db36d"
-uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2"
-version = "0.1.29"
+[[deps.ArrayInterface]]
+deps = ["Adapt", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "SuiteSparse"]
+git-tree-sha1 = "ec9c36854b569323551a6faf2f31fda15e3459a7"
+uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
+version = "7.2.0"
 
 [[deps.Artifacts]]
 uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
 
 [[deps.BFloat16s]]
 deps = ["LinearAlgebra", "Printf", "Random", "Test"]
-git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072"
+git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66"
 uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
-version = "0.2.0"
+version = "0.4.2"
+
+[[deps.BSON]]
+git-tree-sha1 = "86e9781ac28f4e80e9b98f7f96eae21891332ac2"
+uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
+version = "0.3.6"
 
 [[deps.BangBang]]
 deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"]
@@ -112,10 +117,34 @@ uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 version = "0.10.9"
 
 [[deps.CUDA]]
-deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
-git-tree-sha1 = "6717cb9a3425ebb7b31ca4f832823615d175f64a"
+deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions"]
+git-tree-sha1 = "edff14c60784c8f7191a62a23b15a421185bc8a8"
 uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
-version = "3.13.1"
+version = "4.0.1"
+
+[[deps.CUDA_Driver_jll]]
+deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
+git-tree-sha1 = "75d7896d1ec079ef10d3aee8f3668c11354c03a1"
+uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc"
+version = "0.2.0+0"
+
+[[deps.CUDA_Runtime_Discovery]]
+deps = ["Libdl"]
+git-tree-sha1 = "58dd8ec29f54f08c04b052d2c2fa6760b4f4b3a4"
+uuid = "1af6417a-86b4-443c-805f-a4643ffb695f"
+version = "0.1.1"
+
+[[deps.CUDA_Runtime_jll]]
+deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
+git-tree-sha1 = "d3e6ccd30f84936c1a3a53d622d85d7d3f9b9486"
+uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
+version = "0.2.3+2"
+
+[[deps.CUDNN_jll]]
+deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
+git-tree-sha1 = "57011df4fce448828165e566af9befa2ea94350a"
+uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
+version = "8.6.0+3"
 
 [[deps.Cairo_jll]]
 deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
@@ -143,9 +172,9 @@ version = "0.1.9"
 
 [[deps.ChainRules]]
 deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"]
-git-tree-sha1 = "fdde4d8a31cf82b1d136cf6cb53924e8744a832b"
+git-tree-sha1 = "7d20c2fb8ab838e41069398685e7b6b5f89ed85b"
 uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "1.47.0"
+version = "1.48.0"
 
 [[deps.ChainRulesCore]]
 deps = ["Compat", "LinearAlgebra", "SparseArrays"]
@@ -155,9 +184,9 @@ version = "1.15.7"
 
 [[deps.ChangesOfVariables]]
 deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
-git-tree-sha1 = "844b061c104c408b24537482469400af6075aae4"
+git-tree-sha1 = "485193efd2176b88e6622a39a246f8c5b600e74e"
 uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
-version = "0.1.5"
+version = "0.1.6"
 
 [[deps.Chemfiles]]
 deps = ["BinaryProvider", "Chemfiles_jll", "DocStringExtensions", "Libdl"]
@@ -234,16 +263,16 @@ uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3"
 version = "0.3.2"
 
 [[deps.ConformalPrediction]]
-deps = ["CategoricalArrays", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "Statistics"]
+deps = ["CategoricalArrays", "ChainRules", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "Statistics"]
 path = "../ConformalPrediction.jl"
 uuid = "98bfc277-1877-43dc-819b-a3e38c30242f"
 version = "0.1.5"
 
 [[deps.ConstructionBase]]
 deps = ["LinearAlgebra"]
-git-tree-sha1 = "fb21ddd70a051d882a1686a5a550990bbe371a95"
+git-tree-sha1 = "89a9db8d28102b094992472d333674bd1a83ce2a"
 uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
-version = "1.4.1"
+version = "1.5.1"
 
 [[deps.ContextVariablesX]]
 deps = ["Compat", "Logging", "UUIDs"]
@@ -260,7 +289,7 @@ version = "0.6.2"
 deps = ["CSV", "CUDA", "CategoricalArrays", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "PkgTemplates", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "Statistics", "StatsBase", "Tables", "UMAP"]
 path = "../CounterfactualExplanations.jl"
 uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
-version = "0.1.6"
+version = "0.1.7"
 
 [[deps.Crayons]]
 git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
@@ -322,9 +351,9 @@ version = "1.1.0"
 
 [[deps.DiffRules]]
 deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"]
-git-tree-sha1 = "c5b6685d53f933c11404a3ae9822afe30d522494"
+git-tree-sha1 = "a4ad7ef19d2cdc2eff57abbbe68032b1cd0bd8f8"
 uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
-version = "1.12.2"
+version = "1.13.0"
 
 [[deps.Distances]]
 deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"]
@@ -338,9 +367,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 
 [[deps.Distributions]]
 deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"]
-git-tree-sha1 = "74911ad88921455c6afcad1eefa12bd7b1724631"
+git-tree-sha1 = "d71264a7b9a95dca3b8fff4477d94a837346c545"
 uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
-version = "0.25.80"
+version = "0.25.84"
 
 [[deps.DocStringExtensions]]
 deps = ["LibGit2"]
@@ -416,10 +445,10 @@ uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
 version = "0.13.7"
 
 [[deps.FiniteDiff]]
-deps = ["ArrayInterfaceCore", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"]
-git-tree-sha1 = "04ed1f0029b6b3af88343e439b995141cb0d0b8d"
+deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"]
+git-tree-sha1 = "ed1b56934a2f7a65035976985da71b6a65b4f2cf"
 uuid = "6a86dc24-6348-571c-b903-95158fe2bd41"
-version = "2.17.0"
+version = "2.18.0"
 
 [[deps.FixedPointNumbers]]
 deps = ["Statistics"]
@@ -429,9 +458,9 @@ version = "0.8.4"
 
 [[deps.Flux]]
 deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote"]
-git-tree-sha1 = "c258e51850ac1fdc465f62380a61995d4a66d603"
+git-tree-sha1 = "4ff3a1d7b0dd38f2fc38e813bc801f817639c1f2"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.13.12"
+version = "0.13.13"
 
 [[deps.FoldsThreads]]
 deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"]
@@ -453,9 +482,9 @@ version = "0.4.2"
 
 [[deps.ForwardDiff]]
 deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
-git-tree-sha1 = "a69dd6db8a809f78846ff259298678f0d6212180"
+git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d"
 uuid = "f6369f11-7733-5829-9624-2563aa707210"
-version = "0.10.34"
+version = "0.10.35"
 
 [[deps.FreeType2_jll]]
 deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"]
@@ -476,9 +505,9 @@ version = "1.1.3"
 
 [[deps.Functors]]
 deps = ["LinearAlgebra"]
-git-tree-sha1 = "61fa9cf802d35fe1b5b8ea9fbaac4b8f020d19b1"
+git-tree-sha1 = "7ed0833a55979d3d2658a60b901469748a6b9a7c"
 uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-version = "0.4.2"
+version = "0.4.3"
 
 [[deps.Future]]
 deps = ["Random"]
@@ -492,15 +521,15 @@ version = "3.3.8+0"
 
 [[deps.GPUArrays]]
 deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"]
-git-tree-sha1 = "4dfaff044eb2ce11a897fecd85538310e60b91e6"
+git-tree-sha1 = "a28f752ffab0ccd6660fc7af5ad1c9ad176f45f7"
 uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-version = "8.6.2"
+version = "8.6.3"
 
 [[deps.GPUArraysCore]]
 deps = ["Adapt"]
-git-tree-sha1 = "57f7cde02d7a53c9d1d28443b9f11ac5fbe7ebc9"
+git-tree-sha1 = "1cd7f0af1aa58abc02ea1d872953a97359cb87fa"
 uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
-version = "0.1.3"
+version = "0.1.4"
 
 [[deps.GPUCompiler]]
 deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
@@ -562,9 +591,9 @@ version = "1.0.2"
 
 [[deps.HDF5]]
 deps = ["Compat", "HDF5_jll", "Libdl", "Mmap", "Random", "Requires", "UUIDs"]
-git-tree-sha1 = "b5df7c3cab3a00c33c2e09c6bd23982a75e2fbb2"
+git-tree-sha1 = "3dab31542b3da9f25a6a1d11159d4af8fdce7d67"
 uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
-version = "0.16.13"
+version = "0.16.14"
 
 [[deps.HDF5_jll]]
 deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "OpenSSL_jll", "Pkg", "Zlib_jll"]
@@ -609,10 +638,10 @@ uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 version = "0.9.4"
 
 [[deps.ImageShow]]
-deps = ["Base64", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"]
-git-tree-sha1 = "b563cf9ae75a635592fc73d3eb78b86220e55bd8"
+deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"]
+git-tree-sha1 = "ce28c68c900eed3cdbfa418be66ed053e54d4f56"
 uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
-version = "0.3.6"
+version = "0.3.7"
 
 [[deps.Inflate]]
 git-tree-sha1 = "5cd07aab533df5170988219191dfad0519391428"
@@ -657,9 +686,9 @@ uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
 version = "1.2.0"
 
 [[deps.IrrationalConstants]]
-git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
+git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
 uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
-version = "0.1.1"
+version = "0.2.2"
 
 [[deps.IteratorInterfaceExtensions]]
 git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
@@ -697,10 +726,10 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 version = "1.12.0"
 
 [[deps.JpegTurbo_jll]]
-deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "b53380851c6e6664204efb2e62cd24fa5c47e4ba"
+deps = ["Artifacts", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "6f2675ef130a300a112286de91973805fcc5ffbc"
 uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8"
-version = "2.1.2+0"
+version = "2.1.91+0"
 
 [[deps.JuliaVariables]]
 deps = ["MLStyle", "NameResolution"]
@@ -733,9 +762,9 @@ version = "4.16.0"
 
 [[deps.LLVMExtra_jll]]
 deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
-git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576"
+git-tree-sha1 = "7718cf44439c676bc0ec66a87099f41015a522d6"
 uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
-version = "0.0.16+0"
+version = "0.0.16+2"
 
 [[deps.LZO_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -851,9 +880,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 
 [[deps.LogExpFunctions]]
 deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
-git-tree-sha1 = "680e733c3a0a9cea9e935c8c2184aea6a63fa0b5"
+git-tree-sha1 = "0a1b7c2863e44523180fdb3146534e265a91870b"
 uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
-version = "0.3.21"
+version = "0.3.23"
 
 [[deps.Logging]]
 uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -890,15 +919,15 @@ version = "0.7.9"
 
 [[deps.MLJBase]]
 deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
-git-tree-sha1 = "f6667db64f84c5031e3f4e48b5da80e1dd39429d"
+git-tree-sha1 = "6f3a7338e787cbf3460f035c21ee2547f71f8007"
 uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
-version = "0.21.5"
+version = "0.21.6"
 
 [[deps.MLJFlux]]
-deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "ProgressMeter", "Random", "Statistics", "Tables"]
-git-tree-sha1 = "a47257705ebca405a25320b111345a978925bcd5"
+deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"]
+git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0"
 uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
-version = "0.2.7"
+version = "0.2.9"
 
 [[deps.MLJModelInterface]]
 deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
@@ -908,9 +937,9 @@ version = "1.8.0"
 
 [[deps.MLJModels]]
 deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
-git-tree-sha1 = "638c84dea26eb91a538afb52cef11d77556e6807"
+git-tree-sha1 = "1d445497ca058dbc0dbc7528b778707893edb969"
 uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
-version = "0.16.3"
+version = "0.16.4"
 
 [[deps.MLStyle]]
 git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
@@ -960,6 +989,12 @@ git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102"
 uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e"
 version = "0.3.2"
 
+[[deps.Metalhead]]
+deps = ["Artifacts", "BSON", "Flux", "Functors", "LazyArtifacts", "MLUtils", "NNlib", "Random", "Statistics"]
+git-tree-sha1 = "0e95f91cc5f23610f8f270d7397f307b21e19d2b"
+uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
+version = "0.7.4"
+
 [[deps.MicroCollections]]
 deps = ["BangBang", "InitialValues", "Setfield"]
 git-tree-sha1 = "4d5917a26ca33c66c8e5ca3247bd163624d35493"
@@ -1011,15 +1046,15 @@ version = "7.8.3"
 
 [[deps.NNlib]]
 deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
-git-tree-sha1 = "ddf38a5d9140bc8c08ea6158484a455ca3efdd2d"
+git-tree-sha1 = "33ad5a19dc6730d592d8ce91c14354d758e53b0e"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.8.18"
+version = "0.8.19"
 
 [[deps.NNlibCUDA]]
-deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
-git-tree-sha1 = "b05a082b08a3af0e5c576883bc6dfb6513e7e478"
+deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"]
+git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf"
 uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
-version = "0.2.6"
+version = "0.2.7"
 
 [[deps.NPZ]]
 deps = ["FileIO", "ZipFile"]
@@ -1029,9 +1064,9 @@ version = "0.4.3"
 
 [[deps.NaNMath]]
 deps = ["OpenLibm_jll"]
-git-tree-sha1 = "a7c3d1da1189a1c2fe843a3bfa04d18d20eb3211"
+git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4"
 uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
-version = "1.0.1"
+version = "1.0.2"
 
 [[deps.NameResolution]]
 deps = ["PrettyPrint"]
@@ -1120,9 +1155,9 @@ version = "2.0.2"
 
 [[deps.Optimisers]]
 deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "e657acef119cc0de2a8c0762666d3b64727b053b"
+git-tree-sha1 = "e5a1825d3d53aa4ad4fb42bd4927011ad4a78c3d"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
-version = "0.2.14"
+version = "0.2.15"
 
 [[deps.Opus_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -1142,9 +1177,9 @@ version = "10.40.0+0"
 
 [[deps.PDMats]]
 deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
-git-tree-sha1 = "cf494dca75a69712a72b80bc48f59dcf3dea63ec"
+git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1"
 uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
-version = "0.11.16"
+version = "0.11.17"
 
 [[deps.PaddedViews]]
 deps = ["OffsetArrays"]
@@ -1206,9 +1241,9 @@ version = "1.3.4"
 
 [[deps.Plots]]
 deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "Preferences", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SnoopPrecompile", "SparseArrays", "Statistics", "StatsBase", "UUIDs", "UnicodeFun", "Unzip"]
-git-tree-sha1 = "8ac949bd0ebc46a44afb1fdca1094554a84b086e"
+git-tree-sha1 = "da1d3fb7183e38603fcdd2061c47979d91202c97"
 uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
-version = "1.38.5"
+version = "1.38.6"
 
 [[deps.PooledArrays]]
 deps = ["DataAPI", "Future"]
@@ -1356,9 +1391,9 @@ version = "1.1.1"
 
 [[deps.SentinelArrays]]
 deps = ["Dates", "Random"]
-git-tree-sha1 = "c02bd3c9c3fc8463d3591a62a378f90d2d8ab0f3"
+git-tree-sha1 = "77d3c4726515dca71f6d80fbb5e251088defe305"
 uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
-version = "1.3.17"
+version = "1.3.18"
 
 [[deps.Serialization]]
 uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -1422,9 +1457,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 
 [[deps.SpecialFunctions]]
 deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
-git-tree-sha1 = "d75bda01f8c31ebb72df80a46c88b25d1c79c56d"
+git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880"
 uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
-version = "2.1.7"
+version = "2.2.0"
 
 [[deps.SplittablesBase]]
 deps = ["Setfield", "Test"]
@@ -1440,9 +1475,9 @@ version = "0.1.1"
 
 [[deps.StaticArrays]]
 deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
-git-tree-sha1 = "67d3e75e8af8089ea34ce96974d5468d4a008ca6"
+git-tree-sha1 = "2d7d9e1ddadc8407ffd460e24218e37ef52dd9a3"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.5.15"
+version = "1.5.16"
 
 [[deps.StaticArraysCore]]
 git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
@@ -1473,9 +1508,9 @@ version = "0.33.21"
 
 [[deps.StatsFuns]]
 deps = ["ChainRulesCore", "HypergeometricFunctions", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
-git-tree-sha1 = "ab6083f09b3e617e34a956b43e9d51b824206932"
+git-tree-sha1 = "5aa6250a781e567388f3285fb4b0f214a501b4d5"
 uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
-version = "1.1.1"
+version = "1.2.1"
 
 [[deps.Strided]]
 deps = ["LinearAlgebra", "TupleTools"]
@@ -1801,6 +1836,12 @@ git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0"
 uuid = "700de1a5-db45-46bc-99cf-38207098b444"
 version = "0.2.2"
 
+[[deps.cuDNN]]
+deps = ["CEnum", "CUDA", "CUDNN_jll"]
+git-tree-sha1 = "c0ffcb38d1e8c0bbcd3dab2559cf9c369130b2f2"
+uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
+version = "1.0.1"
+
 [[deps.fzf_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
 git-tree-sha1 = "868e669ccb12ba16eaf50cb2957ee2ff61261c56"
diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd
index a05f147e..015fbe0d 100644
--- a/notebooks/conformal.qmd
+++ b/notebooks/conformal.qmd
@@ -3,6 +3,7 @@ using CCE
 using ConformalPrediction
 using CounterfactualExplanations
 using CounterfactualExplanations.Data
+using CounterfactualExplanations.Objectives
 using Flux
 using MLJBase
 using MLJFlux
@@ -107,18 +108,62 @@ p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, t
 plot(p1, p2, size=(800,320))
 ```
 
+### Penalizing Size
+
+```{julia}
+#| output: true
+#| echo: false
+#| label: fig-ce
+#| fig-cap: "Comparison of counterfactuals produced using different generators."
+
+opt = Descent(0.01)
+ordered_names = [
+    "Generic (γ=0.5)",
+    "Generic (γ=0.9)",
+    "Conformal (λ₂=1)",
+    "Conformal (λ₂=10)"
+]
+loss_fun = Objectives.logitbinarycrossentropy
+
+# Generators:
+generators = Dict(
+    ordered_names[1] => GenericGenerator(opt = opt, decision_threshold=0.5),
+    ordered_names[2] => GenericGenerator(opt = opt, decision_threshold=0.9),
+    ordered_names[3] => CCE.ConformalGenerator(loss=loss_fun, opt=opt, λ=[0.1,1]),
+    ordered_names[4] => CCE.ConformalGenerator(loss=loss_fun, opt=opt, λ=[0.1,10]),
+)
+
+counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity) for (name, gen) in generators])
+
+# Plots:
+plts = []
+for name ∈ ordered_names
+    ce = counterfactuals[name]
+    plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=0)
+    plts = vcat(plts..., plt)
+end
+_n = length(generators)
+img_size = 300
+plot(plts..., size=(_n * img_size,1.05*img_size), layout=(1,_n))
+```
+
+
+### Configurable Classification Loss
+
 ```{julia}
 #| output: true
 #| echo: false
 #| label: fig-ce
 #| fig-cap: "Comparison of counterfactuals produced using different generators."
 
+opt = Descent(0.01)
 ordered_names = [
     "Generic (γ=0.5)",
     "Generic (γ=0.9)",
     "Conformal (λ₂=1)",
     "Conformal (λ₂=10)"
 ]
+loss_fun = Objectives.logitbinarycrossentropy
 
 # Generators:
 generators = Dict(
diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl
index 056387b1..280e483c 100644
--- a/src/ConformalGenerator.jl
+++ b/src/ConformalGenerator.jl
@@ -1,5 +1,8 @@
+using CategoricalArrays
 using CounterfactualExplanations
 using CounterfactualExplanations.Generators
+using CounterfactualExplanations.Models: binary_to_onehot
+using CounterfactualExplanations.Objectives
 using Flux
 using LinearAlgebra
 using Parameters
@@ -7,7 +10,7 @@ using SliceMap
 using Statistics
 
 mutable struct ConformalGenerator <: AbstractGradientBasedGenerator
-    loss::Union{Nothing,Symbol} # loss function
+    loss::Union{Nothing,Function} # loss function
     complexity::Function # complexity function
     λ::Union{AbstractFloat,AbstractVector} # strength of penalty
     decision_threshold::Union{Nothing,AbstractFloat}
@@ -26,13 +29,12 @@ end
 end
 
 """
-    ConformalGenerator(
-        ;
-        loss::Symbol=:logitbinarycrossentropy,
+    ConformalGenerator(;
+        loss::Union{Nothing,Function}=conformal_training_loss,
         complexity::Function=norm,
-        λ::AbstractFloat=0.1,
-        opt::Flux.Optimise.AbstractOptimiser=Flux.Optimise.Descent(),
-        Ï„::AbstractFloat=1e-5
+        λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0],
+        decision_threshold=nothing,
+        kwargs...
     )
 
 An outer constructor method that instantiates a generic generator.
@@ -43,7 +45,7 @@ generator = ConformalGenerator()
 ```
 """
 function ConformalGenerator(;
-    loss::Union{Nothing,Symbol}=nothing,
+    loss::Union{Nothing,Function}=conformal_training_loss,
     complexity::Function=norm,
     λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0],
     decision_threshold=nothing,
@@ -53,6 +55,32 @@ function ConformalGenerator(;
     ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ, params.κ, params.temp)
 end
 
+@doc raw"""
+    conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...)
+
+A configurable classification loss function for Conformal Predictors.
+"""
+function conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...)
+    conf_model = counterfactual_explanation.M.model
+    fitresult = counterfactual_explanation.M.fitresult
+    X = CounterfactualExplanations.decode_state(counterfactual_explanation)
+    y = counterfactual_explanation.target_encoded[:,:,1]
+    if counterfactual_explanation.M.likelihood == :classification_binary
+        y = binary_to_onehot(y)
+    end
+    y = permutedims(y)
+    generator = counterfactual_explanation.generator
+    loss = SliceMap.slicemap(X, dims=(1, 2)) do x
+        x = Matrix(x)
+        ConformalPrediction.classification_loss(
+            conf_model, fitresult, x, y;
+            temp=generator.temp
+        )
+    end
+    loss = mean(loss)
+    return loss
+end
+
 """
     set_size_penalty(
         generator::ConformalGenerator,
-- 
GitLab