diff --git a/Manifest.toml b/Manifest.toml
index 4531e744c4232042cd1891c52c5235f2fdd795ca..06b80abd0a09ec27e3d9f94fbf4dc29493113df6 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -49,20 +49,26 @@ uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97"
 version = "0.5.4"
 
 [[deps.Arpack_jll]]
-deps = ["Libdl", "OpenBLAS_jll", "Pkg"]
-git-tree-sha1 = "e214a9b9bd1b4e1b4f15b22c0994862b66af7ff7"
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "OpenBLAS_jll", "Pkg"]
+git-tree-sha1 = "5ba6c757e8feccf03a1554dfaf3e26b3cfc7fd5e"
 uuid = "68821587-b530-5797-8361-c406ea357684"
-version = "3.5.0+3"
+version = "3.5.1+1"
 
 [[deps.ArrayInterface]]
 deps = ["Adapt", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "SuiteSparse"]
-git-tree-sha1 = "a89acc90c551067cd84119ff018619a1a76c6277"
+git-tree-sha1 = "38911c7737e123b28182d89027f4216cfc8a9da7"
 uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
-version = "7.2.1"
+version = "7.4.3"
 
 [[deps.Artifacts]]
 uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
 
+[[deps.Atomix]]
+deps = ["UnsafeAtomics"]
+git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be"
+uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
+version = "0.1.0"
+
 [[deps.AxisAlgorithms]]
 deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"]
 git-tree-sha1 = "66771c8d21c8ff5e3a93379480a2307ac36863f7"
@@ -76,9 +82,9 @@ uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 version = "0.4.2"
 
 [[deps.BSON]]
-git-tree-sha1 = "86e9781ac28f4e80e9b98f7f96eae21891332ac2"
+git-tree-sha1 = "2208958832d6e1b59e49f53697483a84ca8d664e"
 uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
-version = "0.3.6"
+version = "0.3.7"
 
 [[deps.BangBang]]
 deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"]
@@ -94,11 +100,10 @@ git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
 uuid = "9718e550-a3fa-408a-8086-8db961cd8217"
 version = "0.1.1"
 
-[[deps.BinaryProvider]]
-deps = ["Libdl", "Logging", "SHA"]
-git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058"
-uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
-version = "0.5.10"
+[[deps.BitFlags]]
+git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d"
+uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35"
+version = "0.1.7"
 
 [[deps.BufferedStreams]]
 git-tree-sha1 = "bb065b14d7f941b8617bc323063dbe79f55d16ea"
@@ -123,34 +128,34 @@ uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 version = "0.10.9"
 
 [[deps.CUDA]]
-deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions"]
-git-tree-sha1 = "edff14c60784c8f7191a62a23b15a421185bc8a8"
+deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"]
+git-tree-sha1 = "8547829ee0da896ce48a24b8d2f4bf929cf3e45e"
 uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
-version = "4.0.1"
+version = "4.1.4"
 
 [[deps.CUDA_Driver_jll]]
 deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
-git-tree-sha1 = "75d7896d1ec079ef10d3aee8f3668c11354c03a1"
+git-tree-sha1 = "498f45593f6ddc0adff64a9310bb6710e851781b"
 uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc"
-version = "0.2.0+0"
+version = "0.5.0+1"
 
 [[deps.CUDA_Runtime_Discovery]]
 deps = ["Libdl"]
-git-tree-sha1 = "58dd8ec29f54f08c04b052d2c2fa6760b4f4b3a4"
+git-tree-sha1 = "bcc4a23cbbd99c8535a5318455dcf0f2546ec536"
 uuid = "1af6417a-86b4-443c-805f-a4643ffb695f"
-version = "0.1.1"
+version = "0.2.2"
 
 [[deps.CUDA_Runtime_jll]]
-deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
-git-tree-sha1 = "d3e6ccd30f84936c1a3a53d622d85d7d3f9b9486"
+deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
+git-tree-sha1 = "81eed046f28a0cdd0dc1f61d00a49061b7cc9433"
 uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
-version = "0.2.3+2"
+version = "0.5.0+2"
 
 [[deps.CUDNN_jll]]
-deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
-git-tree-sha1 = "aafe89dfde54011993c4029d3be3e037fd63db07"
+deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
+git-tree-sha1 = "2918fbffb50e3b7a0b9127617587afa76d4276e8"
 uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
-version = "8.6.0+5"
+version = "8.8.1+0"
 
 [[deps.Cairo_jll]]
 deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
@@ -214,9 +219,9 @@ version = "0.14.4"
 
 [[deps.CodeTracking]]
 deps = ["InteractiveUtils", "UUIDs"]
-git-tree-sha1 = "d57c99cc7e637165c81b30eb268eabe156a45c49"
+git-tree-sha1 = "d730914ef30a06732bdd9f763f6cc32e92ffbff1"
 uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
-version = "1.2.2"
+version = "1.3.1"
 
 [[deps.CodecZlib]]
 deps = ["TranscodingStreams", "Zlib_jll"]
@@ -320,10 +325,10 @@ uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
 version = "1.14.0"
 
 [[deps.DataDeps]]
-deps = ["BinaryProvider", "HTTP", "Libdl", "Reexport", "SHA", "p7zip_jll"]
-git-tree-sha1 = "e299d8267135ef2f9c941a764006697082c1e7e8"
+deps = ["HTTP", "Libdl", "Reexport", "SHA", "p7zip_jll"]
+git-tree-sha1 = "bc0a264d3e7b3eeb0b6fc9f6481f970697f29805"
 uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
-version = "0.7.8"
+version = "0.7.10"
 
 [[deps.DataFrames]]
 deps = ["Compat", "DataAPI", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SnoopPrecompile", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
@@ -381,9 +386,9 @@ version = "1.13.0"
 
 [[deps.Distances]]
 deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"]
-git-tree-sha1 = "3258d0659f812acde79e8a74b11f17ac06d0ca04"
+git-tree-sha1 = "49eba9ad9f7ead780bfb7ee319f962c811c6d3b2"
 uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
-version = "0.10.7"
+version = "0.10.8"
 
 [[deps.Distributed]]
 deps = ["Random", "Serialization", "Sockets"]
@@ -391,9 +396,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 
 [[deps.Distributions]]
 deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"]
-git-tree-sha1 = "da9e1a9058f8d3eec3a8c9fe4faacfb89180066b"
+git-tree-sha1 = "13027f188d26206b9e7b863036f87d2f2e7d013a"
 uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
-version = "0.25.86"
+version = "0.25.87"
 
 [[deps.DocStringExtensions]]
 deps = ["LibGit2"]
@@ -476,15 +481,15 @@ uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
 
 [[deps.FillArrays]]
 deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
-git-tree-sha1 = "3b245d1e50466ca0c9529e2033a3c92387c59c2f"
+git-tree-sha1 = "7072f1e3e5a8be51d525d64f63d3ec1287ff2790"
 uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
-version = "0.13.9"
+version = "0.13.11"
 
 [[deps.FiniteDiff]]
 deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"]
-git-tree-sha1 = "ed1b56934a2f7a65035976985da71b6a65b4f2cf"
+git-tree-sha1 = "03fcb1c42ec905d15b305359603888ec3e65f886"
 uuid = "6a86dc24-6348-571c-b903-95158fe2bd41"
-version = "2.18.0"
+version = "2.19.0"
 
 [[deps.FixedPointNumbers]]
 deps = ["Statistics"]
@@ -493,10 +498,10 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
 version = "0.8.4"
 
 [[deps.Flux]]
-deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote"]
-git-tree-sha1 = "4ff3a1d7b0dd38f2fc38e813bc801f817639c1f2"
+deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote", "cuDNN"]
+git-tree-sha1 = "3f6f32ec0bfd80be0cb65907cf74ec796a632012"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.13.13"
+version = "0.13.15"
 
 [[deps.FoldsThreads]]
 deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"]
@@ -541,9 +546,9 @@ version = "1.1.3"
 
 [[deps.Functors]]
 deps = ["LinearAlgebra"]
-git-tree-sha1 = "7ed0833a55979d3d2658a60b901469748a6b9a7c"
+git-tree-sha1 = "478f8c3145bb91d82c2cf20433e8c1b30df454cc"
 uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-version = "0.4.3"
+version = "0.4.4"
 
 [[deps.Future]]
 deps = ["Random"]
@@ -557,9 +562,9 @@ version = "3.3.8+0"
 
 [[deps.GPUArrays]]
 deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"]
-git-tree-sha1 = "7a2e790b1e2e6f648cfb25c4500c5de1f7b375ef"
+git-tree-sha1 = "9ade6983c3dbbd492cf5729f865fe030d1541463"
 uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-version = "8.6.5"
+version = "8.6.6"
 
 [[deps.GPUArraysCore]]
 deps = ["Adapt"]
@@ -568,22 +573,22 @@ uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
 version = "0.1.4"
 
 [[deps.GPUCompiler]]
-deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
-git-tree-sha1 = "19d693666a304e8c371798f4900f7435558c7cde"
+deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
+git-tree-sha1 = "e9a9173cd77e16509cdf9c1663fda19b22a518b7"
 uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
-version = "0.17.3"
+version = "0.19.3"
 
 [[deps.GR]]
 deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"]
-git-tree-sha1 = "4423d87dc2d3201f3f1768a29e807ddc8cc867ef"
+git-tree-sha1 = "011a22022ed2fb0352a9bded0fa9d3793a8db362"
 uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
-version = "0.71.8"
+version = "0.72.1"
 
 [[deps.GR_jll]]
 deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt5Base_jll", "Zlib_jll", "libpng_jll"]
-git-tree-sha1 = "3657eb348d44575cc5560c80d7e55b812ff6ffe1"
+git-tree-sha1 = "7ea8ead860c85b27e83d198ea54bb2f387db9fc3"
 uuid = "d2c73de3-f751-5644-a686-071e5b155ba9"
-version = "0.71.8+0"
+version = "0.72.1+1"
 
 [[deps.GZip]]
 deps = ["Libdl"]
@@ -604,9 +609,9 @@ uuid = "7746bdde-850d-59dc-9ae8-88ece973131d"
 version = "2.74.0+2"
 
 [[deps.Glob]]
-git-tree-sha1 = "4df9f7e06108728ebf00a0a11edee4b29a482bb2"
+git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496"
 uuid = "c27321d9-0574-5035-807b-f59d2c89b15c"
-version = "1.3.0"
+version = "1.3.1"
 
 [[deps.Graphics]]
 deps = ["Colors", "LinearAlgebra", "NaNMath"]
@@ -638,10 +643,10 @@ uuid = "0234f1f7-429e-5d53-9886-15a909be8d59"
 version = "1.12.2+2"
 
 [[deps.HTTP]]
-deps = ["Base64", "Dates", "IniFile", "Logging", "MbedTLS", "NetworkOptions", "Sockets", "URIs"]
-git-tree-sha1 = "0fa77022fe4b511826b39c894c90daf5fce3334a"
+deps = ["Base64", "CodecZlib", "Dates", "IniFile", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"]
+git-tree-sha1 = "37e4657cd56b11abe3d10cd4a1ec5fbdb4180263"
 uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3"
-version = "0.9.17"
+version = "1.7.4"
 
 [[deps.HarfBuzz_jll]]
 deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"]
@@ -656,16 +661,16 @@ uuid = "eafb193a-b7ab-5a9e-9068-77385905fa72"
 version = "0.5.2"
 
 [[deps.HypergeometricFunctions]]
-deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions", "Test"]
-git-tree-sha1 = "709d864e3ed6e3545230601f94e11ebc65994641"
+deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
+git-tree-sha1 = "432b5b03176f8182bd6841fbfc42c718506a2d5f"
 uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
-version = "0.3.11"
+version = "0.3.15"
 
 [[deps.IRTools]]
 deps = ["InteractiveUtils", "MacroTools", "Test"]
-git-tree-sha1 = "2af2fe19f0d5799311a6491267a14817ad9fbd20"
+git-tree-sha1 = "0ade27f0c49cebd8db2523c4eeccf779407cf12c"
 uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
-version = "0.4.8"
+version = "0.4.9"
 
 [[deps.ImageBase]]
 deps = ["ImageCore", "Reexport"]
@@ -708,9 +713,9 @@ version = "1.4.0"
 
 [[deps.IntelOpenMP_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "d979e54b71da82f3a65b62553da4fc3d18c9004c"
+git-tree-sha1 = "0cb9352ef2e01574eeebdb102948a58740dcaf83"
 uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0"
-version = "2018.0.3+2"
+version = "2023.1.0+0"
 
 [[deps.InteractiveUtils]]
 deps = ["Markdown"]
@@ -769,9 +774,9 @@ version = "1.4.1"
 
 [[deps.JSON]]
 deps = ["Dates", "Mmap", "Parsers", "Unicode"]
-git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e"
+git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a"
 uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
-version = "0.21.3"
+version = "0.21.4"
 
 [[deps.JSON3]]
 deps = ["Dates", "Mmap", "Parsers", "SnoopPrecompile", "StructTypes", "UUIDs"]
@@ -780,7 +785,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 version = "1.12.0"
 
 [[deps.JointEnergyModels]]
-deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
+deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
 path = "../JointEnergyModels.jl"
 uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
 version = "0.1.0"
@@ -802,6 +807,12 @@ git-tree-sha1 = "4aeebbfcf0615641ec4b0782b73b638eeeabd62e"
 uuid = "5cadff95-7770-533d-a838-a1bf817ee6e0"
 version = "0.3.0"
 
+[[deps.KernelAbstractions]]
+deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "SnoopPrecompile", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
+git-tree-sha1 = "976231af02176082fb266a9f96a59da51fcacf20"
+uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
+version = "0.9.2"
+
 [[deps.KernelDensity]]
 deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"]
 git-tree-sha1 = "9816b296736292a80b9a3200eb7fbb57aaa3917a"
@@ -822,15 +833,15 @@ version = "3.0.0+1"
 
 [[deps.LLVM]]
 deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
-git-tree-sha1 = "1c614dfbecbaee4897b506bba2b432bf0d21f2ed"
+git-tree-sha1 = "a8960cae30b42b66dd41808beb76490519f6f9e2"
 uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
-version = "4.17.0"
+version = "5.0.0"
 
 [[deps.LLVMExtra_jll]]
 deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
-git-tree-sha1 = "e46e3a40daddcbe851f86db0ec4a4a3d4badf800"
+git-tree-sha1 = "09b7505cc0b1cee87e5d4a26eea61d2e1b0dcd35"
 uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
-version = "0.0.19+0"
+version = "0.0.21+0"
 
 [[deps.LZO_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -953,11 +964,17 @@ version = "0.3.23"
 [[deps.Logging]]
 uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
 
+[[deps.LoggingExtras]]
+deps = ["Dates", "Logging"]
+git-tree-sha1 = "cedb76b37bc5a6c702ade66be44f831fa23c681e"
+uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36"
+version = "1.0.0"
+
 [[deps.LossFunctions]]
-deps = ["InteractiveUtils", "Markdown", "RecipesBase"]
-git-tree-sha1 = "53cd63a12f06a43eef6f4aafb910ac755c122be7"
+deps = ["CategoricalArrays", "Markdown"]
+git-tree-sha1 = "d4c7ff8c7281943371e1725000fd538a699024d0"
 uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
-version = "0.8.0"
+version = "0.9.0"
 
 [[deps.LsqFit]]
 deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"]
@@ -985,9 +1002,9 @@ version = "0.7.9"
 
 [[deps.MLJBase]]
 deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
-git-tree-sha1 = "6f3a7338e787cbf3460f035c21ee2547f71f8007"
+git-tree-sha1 = "e99e84c7ca1696b38e2a8823f53ac9cc775599dd"
 uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
-version = "0.21.6"
+version = "0.21.9"
 
 [[deps.MLJEnsembles]]
 deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"]
@@ -1009,9 +1026,9 @@ version = "1.8.0"
 
 [[deps.MLJModels]]
 deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
-git-tree-sha1 = "1d445497ca058dbc0dbc7528b778707893edb969"
+git-tree-sha1 = "21acf47dc53ccc3d68e38ac7629756cd09b599f5"
 uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
-version = "0.16.4"
+version = "0.16.6"
 
 [[deps.MLStyle]]
 git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
@@ -1122,10 +1139,10 @@ uuid = "d41bc354-129a-5804-8e4c-c37616107c6c"
 version = "7.8.3"
 
 [[deps.NNlib]]
-deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
-git-tree-sha1 = "33ad5a19dc6730d592d8ce91c14354d758e53b0e"
+deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
+git-tree-sha1 = "99e6dbb50d8a96702dc60954569e9fe7291cc55d"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.8.19"
+version = "0.8.20"
 
 [[deps.NNlibCUDA]]
 deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"]
@@ -1164,9 +1181,9 @@ version = "0.3.5"
 
 [[deps.NearestNeighborModels]]
 deps = ["Distances", "FillArrays", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "NearestNeighbors", "Statistics", "StatsBase", "Tables"]
-git-tree-sha1 = "727b8f1c3f9fec6b1a805ba9bef72c73758eda02"
+git-tree-sha1 = "c2179f9d8de066c481b889a1426068c5831bb10b"
 uuid = "636a865e-7cf4-491e-846c-de09b730eb36"
-version = "0.2.1"
+version = "0.2.2"
 
 [[deps.NearestNeighbors]]
 deps = ["Distances", "StaticArrays"]
@@ -1211,6 +1228,12 @@ deps = ["Artifacts", "Libdl"]
 uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
 version = "0.8.1+0"
 
+[[deps.OpenSSL]]
+deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"]
+git-tree-sha1 = "5b3e170ea0724f1e3ed6018c5b006c190f80e87d"
+uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c"
+version = "1.3.5"
+
 [[deps.OpenSSL_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
 git-tree-sha1 = "9ff31d101d987eb9d66bd8b176ac7c277beccd09"
@@ -1231,9 +1254,9 @@ version = "2.0.2"
 
 [[deps.Optimisers]]
 deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "e5a1825d3d53aa4ad4fb42bd4927011ad4a78c3d"
+git-tree-sha1 = "6a01f65dd8583dee82eecc2a19b0ff21521aa749"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
-version = "0.2.15"
+version = "0.2.18"
 
 [[deps.Opus_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -1242,9 +1265,9 @@ uuid = "91d4177d-7536-5919-b921-800302f37372"
 version = "1.3.2+0"
 
 [[deps.OrderedCollections]]
-git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
+git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282"
 uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
-version = "1.4.1"
+version = "1.6.0"
 
 [[deps.PCRE2_jll]]
 deps = ["Artifacts", "Libdl"]
@@ -1259,9 +1282,9 @@ version = "0.11.17"
 
 [[deps.PaddedViews]]
 deps = ["OffsetArrays"]
-git-tree-sha1 = "03a7a85b76381a3d04c7a1656039197e70eda03d"
+git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f"
 uuid = "5432bcbf-9aad-5242-b902-cca2824c8663"
-version = "0.5.11"
+version = "0.5.12"
 
 [[deps.Parameters]]
 deps = ["OrderedCollections", "UnPack"]
@@ -1299,9 +1322,9 @@ version = "1.8.0"
 
 [[deps.PkgTemplates]]
 deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"]
-git-tree-sha1 = "e93643cc634d7551f68b63739d205d22216cec7c"
+git-tree-sha1 = "c0f12580abb41d7d11c1c7c65a1ff410f84c61e3"
 uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1"
-version = "0.7.32"
+version = "0.7.33"
 
 [[deps.PlotThemes]]
 deps = ["PlotUtils", "Statistics"]
@@ -1317,9 +1340,9 @@ version = "1.3.4"
 
 [[deps.Plots]]
 deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "Preferences", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SnoopPrecompile", "SparseArrays", "Statistics", "StatsBase", "UUIDs", "UnicodeFun", "Unzip"]
-git-tree-sha1 = "da1d3fb7183e38603fcdd2061c47979d91202c97"
+git-tree-sha1 = "5434b0ee344eaf2854de251f326df8720f6a7b55"
 uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
-version = "1.38.6"
+version = "1.38.10"
 
 [[deps.PooledArrays]]
 deps = ["DataAPI", "Future"]
@@ -1501,6 +1524,11 @@ git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de"
 uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f"
 version = "1.0.3"
 
+[[deps.SimpleBufferStream]]
+git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1"
+uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7"
+version = "1.1.0"
+
 [[deps.SimpleTraits]]
 deps = ["InteractiveUtils", "MacroTools"]
 git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231"
@@ -1552,9 +1580,9 @@ version = "0.1.1"
 
 [[deps.StaticArrays]]
 deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
-git-tree-sha1 = "6aa098ef1012364f2ede6b17bf358c7f1fbe90d4"
+git-tree-sha1 = "63e84b7fdf5021026d0f17f76af7c57772313d99"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.5.17"
+version = "1.5.21"
 
 [[deps.StaticArraysCore]]
 git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
@@ -1573,9 +1601,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [[deps.StatsAPI]]
 deps = ["LinearAlgebra"]
-git-tree-sha1 = "f9af7f195fb13589dd2e2d57fdb401717d2eb1f6"
+git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7"
 uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
-version = "1.5.0"
+version = "1.6.0"
 
 [[deps.StatsBase]]
 deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
@@ -1686,9 +1714,9 @@ version = "0.2.20"
 
 [[deps.TranscodingStreams]]
 deps = ["Random", "Test"]
-git-tree-sha1 = "94f38103c984f89cf77c402f2a68dbd870f8165f"
+git-tree-sha1 = "0b829474fed270a4b0ab07117dce9b9a2fa7581a"
 uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
-version = "0.9.11"
+version = "0.9.12"
 
 [[deps.Transducers]]
 deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
@@ -1732,14 +1760,25 @@ version = "0.4.1"
 
 [[deps.UnicodePlots]]
 deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "Printf", "Requires", "SnoopPrecompile", "SparseArrays", "StaticArrays", "StatsBase"]
-git-tree-sha1 = "a5bcfc23e352f499a1a46f428d0d3d7fb9e4fc11"
+git-tree-sha1 = "2825e58f6ec3cab889dfa2c824f8d89b9f7ee731"
 uuid = "b8865327-cd53-5732-bb35-84acbb429228"
-version = "3.4.1"
+version = "3.5.1"
+
+[[deps.UnsafeAtomics]]
+git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
+uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
+version = "0.2.1"
+
+[[deps.UnsafeAtomicsLLVM]]
+deps = ["LLVM", "UnsafeAtomics"]
+git-tree-sha1 = "ea37e6066bf194ab78f4e747f5245261f17a7175"
+uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
+version = "0.1.2"
 
 [[deps.Unzip]]
-git-tree-sha1 = "34db80951901073501137bdbc3d5a8e7bbd06670"
+git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78"
 uuid = "41fe7b60-77ed-43a1-b4f0-825fd5a5650d"
-version = "0.1.2"
+version = "0.2.0"
 
 [[deps.Wayland_jll]]
 deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"]
@@ -1927,27 +1966,27 @@ version = "1.2.12+3"
 
 [[deps.Zstd_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "c6edfe154ad7b313c01aceca188c05c835c67360"
+git-tree-sha1 = "49ce682769cd5de6c72dcf1b94ed7790cd08974c"
 uuid = "3161d3a3-bdf6-5164-811a-617609db77b4"
-version = "1.5.4+0"
+version = "1.5.5+0"
 
 [[deps.Zygote]]
 deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SnoopPrecompile", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "4df8f470806a45a8630ac8f597304821dc8e8838"
+git-tree-sha1 = "987ae5554ca90e837594a0f30325eeb5e7303d1e"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.59"
+version = "0.6.60"
 
 [[deps.ZygoteRules]]
-deps = ["MacroTools"]
-git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0"
+deps = ["ChainRulesCore", "MacroTools"]
+git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d"
 uuid = "700de1a5-db45-46bc-99cf-38207098b444"
-version = "0.2.2"
+version = "0.2.3"
 
 [[deps.cuDNN]]
 deps = ["CEnum", "CUDA", "CUDNN_jll"]
-git-tree-sha1 = "c0ffcb38d1e8c0bbcd3dab2559cf9c369130b2f2"
+git-tree-sha1 = "3aa15aba7aad5be8b9b3c1b77a9b81e3e1357280"
 uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
-version = "1.0.1"
+version = "1.0.2"
 
 [[deps.fzf_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml
index 6dbe102ca9eb0869445e2279efae73b9f686cf9d..400facb7be5d41186be3476bfe291d7420d08ee1 100644
--- a/notebooks/Manifest.toml
+++ b/notebooks/Manifest.toml
@@ -2,7 +2,7 @@
 
 julia_version = "1.8.5"
 manifest_format = "2.0"
-project_hash = "bb24fa6d048fab99674a941d85c45a034b033aae"
+project_hash = "b4b125f21013c1ac841e2bb761cd8922630f9f03"
 
 [[deps.AbstractFFTs]]
 deps = ["ChainRulesCore", "LinearAlgebra"]
@@ -156,9 +156,9 @@ version = "0.10.9"
 
 [[deps.CUDA]]
 deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"]
-git-tree-sha1 = "6591ddc73adb429b9d97145c8197a0ac81664ab4"
+git-tree-sha1 = "8547829ee0da896ce48a24b8d2f4bf929cf3e45e"
 uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
-version = "4.1.3"
+version = "4.1.4"
 
 [[deps.CUDA_Driver_jll]]
 deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
@@ -168,9 +168,9 @@ version = "0.5.0+1"
 
 [[deps.CUDA_Runtime_Discovery]]
 deps = ["Libdl"]
-git-tree-sha1 = "6c8fceaaa6850dea627288ac3bb86fdcdf05e326"
+git-tree-sha1 = "bcc4a23cbbd99c8535a5318455dcf0f2546ec536"
 uuid = "1af6417a-86b4-443c-805f-a4643ffb695f"
-version = "0.2.0"
+version = "0.2.2"
 
 [[deps.CUDA_Runtime_jll]]
 deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
@@ -275,9 +275,9 @@ version = "0.14.4"
 
 [[deps.CodeTracking]]
 deps = ["InteractiveUtils", "UUIDs"]
-git-tree-sha1 = "0683f086e2ef8e2fdacd3f246b35c59e7088b283"
+git-tree-sha1 = "d730914ef30a06732bdd9f763f6cc32e92ffbff1"
 uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
-version = "1.3.0"
+version = "1.3.1"
 
 [[deps.CodecZlib]]
 deps = ["TranscodingStreams", "Zlib_jll"]
@@ -475,9 +475,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 
 [[deps.Distributions]]
 deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"]
-git-tree-sha1 = "da9e1a9058f8d3eec3a8c9fe4faacfb89180066b"
+git-tree-sha1 = "13027f188d26206b9e7b863036f87d2f2e7d013a"
 uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
-version = "0.25.86"
+version = "0.25.87"
 
 [[deps.DocStringExtensions]]
 deps = ["LibGit2"]
@@ -601,9 +601,9 @@ version = "0.8.4"
 
 [[deps.Flux]]
 deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote", "cuDNN"]
-git-tree-sha1 = "e657a9aad824de4211606f113edd0b50d5e1f6db"
+git-tree-sha1 = "3f6f32ec0bfd80be0cb65907cf74ec796a632012"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.13.14"
+version = "0.13.15"
 
 [[deps.FoldsThreads]]
 deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"]
@@ -694,21 +694,21 @@ version = "0.1.4"
 
 [[deps.GPUCompiler]]
 deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
-git-tree-sha1 = "590d394bad1055b798b2f9b308327ba871b7badf"
+git-tree-sha1 = "e9a9173cd77e16509cdf9c1663fda19b22a518b7"
 uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
-version = "0.19.0"
+version = "0.19.3"
 
 [[deps.GR]]
 deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"]
-git-tree-sha1 = "4423d87dc2d3201f3f1768a29e807ddc8cc867ef"
+git-tree-sha1 = "011a22022ed2fb0352a9bded0fa9d3793a8db362"
 uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
-version = "0.71.8"
+version = "0.72.1"
 
 [[deps.GR_jll]]
 deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt5Base_jll", "Zlib_jll", "libpng_jll"]
-git-tree-sha1 = "3657eb348d44575cc5560c80d7e55b812ff6ffe1"
+git-tree-sha1 = "7ea8ead860c85b27e83d198ea54bb2f387db9fc3"
 uuid = "d2c73de3-f751-5644-a686-071e5b155ba9"
-version = "0.71.8+0"
+version = "0.72.1+1"
 
 [[deps.GZip]]
 deps = ["Libdl"]
@@ -812,9 +812,9 @@ version = "0.5.2"
 
 [[deps.HypergeometricFunctions]]
 deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
-git-tree-sha1 = "d926e9c297ef4607866e8ef5df41cde1a642917f"
+git-tree-sha1 = "432b5b03176f8182bd6841fbfc42c718506a2d5f"
 uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
-version = "0.3.14"
+version = "0.3.15"
 
 [[deps.IRTools]]
 deps = ["InteractiveUtils", "MacroTools", "Test"]
@@ -963,9 +963,9 @@ version = "0.1.5"
 
 [[deps.IntelOpenMP_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "d979e54b71da82f3a65b62553da4fc3d18c9004c"
+git-tree-sha1 = "0cb9352ef2e01574eeebdb102948a58740dcaf83"
 uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0"
-version = "2018.0.3+2"
+version = "2023.1.0+0"
 
 [[deps.InteractiveUtils]]
 deps = ["Markdown"]
@@ -1041,9 +1041,9 @@ version = "1.4.1"
 
 [[deps.JSON]]
 deps = ["Dates", "Mmap", "Parsers", "Unicode"]
-git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e"
+git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a"
 uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
-version = "0.21.3"
+version = "0.21.4"
 
 [[deps.JSON3]]
 deps = ["Dates", "Mmap", "Parsers", "SnoopPrecompile", "StructTypes", "UUIDs"]
@@ -1052,7 +1052,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 version = "1.12.0"
 
 [[deps.JointEnergyModels]]
-deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
+deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
 path = "../../JointEnergyModels.jl"
 uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
 version = "0.1.0"
@@ -1082,9 +1082,9 @@ version = "0.3.0"
 
 [[deps.KernelAbstractions]]
 deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "SnoopPrecompile", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
-git-tree-sha1 = "350a880e80004f4d5d82a17f737d8fcdc56c3462"
+git-tree-sha1 = "976231af02176082fb266a9f96a59da51fcacf20"
 uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
-version = "0.9.1"
+version = "0.9.2"
 
 [[deps.KernelDensity]]
 deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"]
@@ -1250,10 +1250,10 @@ uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36"
 version = "1.0.0"
 
 [[deps.LossFunctions]]
-deps = ["InteractiveUtils", "Markdown", "RecipesBase"]
-git-tree-sha1 = "f27330f931944ecee340f004302db724c1985955"
+deps = ["CategoricalArrays", "Markdown"]
+git-tree-sha1 = "d4c7ff8c7281943371e1725000fd538a699024d0"
 uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
-version = "0.8.1"
+version = "0.9.0"
 
 [[deps.LsqFit]]
 deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"]
@@ -1281,9 +1281,9 @@ version = "0.7.9"
 
 [[deps.MLJBase]]
 deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
-git-tree-sha1 = "37a311b0cd581764fc460f6632e6219dc32f9427"
+git-tree-sha1 = "e99e84c7ca1696b38e2a8823f53ac9cc775599dd"
 uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
-version = "0.21.8"
+version = "0.21.9"
 
 [[deps.MLJEnsembles]]
 deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"]
@@ -1305,9 +1305,9 @@ version = "1.8.0"
 
 [[deps.MLJModels]]
 deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
-git-tree-sha1 = "865558dcdb963789ba82651b990d0fb9c5e8dd59"
+git-tree-sha1 = "21acf47dc53ccc3d68e38ac7629756cd09b599f5"
 uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
-version = "0.16.5"
+version = "0.16.6"
 
 [[deps.MLStyle]]
 git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
@@ -1360,9 +1360,9 @@ version = "1.2.0"
 
 [[deps.MathTeXEngine]]
 deps = ["AbstractTrees", "Automa", "DataStructures", "FreeTypeAbstraction", "GeometryBasics", "LaTeXStrings", "REPL", "RelocatableFolders", "Test", "UnicodeFun"]
-git-tree-sha1 = "64890e1e8087b71c03bd6b8af99b49c805b2a78d"
+git-tree-sha1 = "8f52dbaa1351ce4cb847d95568cb29e62a307d93"
 uuid = "0a4f8689-d25c-4efe-a92b-7142dfc1aa53"
-version = "0.5.5"
+version = "0.5.6"
 
 [[deps.MbedTLS]]
 deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"]
@@ -1453,10 +1453,10 @@ uuid = "d41bc354-129a-5804-8e4c-c37616107c6c"
 version = "7.8.3"
 
 [[deps.NNlib]]
-deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
-git-tree-sha1 = "33ad5a19dc6730d592d8ce91c14354d758e53b0e"
+deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
+git-tree-sha1 = "99e6dbb50d8a96702dc60954569e9fe7291cc55d"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.8.19"
+version = "0.8.20"
 
 [[deps.NNlibCUDA]]
 deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"]
@@ -1562,9 +1562,9 @@ version = "0.8.1+0"
 
 [[deps.OpenSSL]]
 deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"]
-git-tree-sha1 = "6503b77492fd7fcb9379bf73cd31035670e3c509"
+git-tree-sha1 = "5b3e170ea0724f1e3ed6018c5b006c190f80e87d"
 uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c"
-version = "1.3.3"
+version = "1.3.5"
 
 [[deps.OpenSSL_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -1586,9 +1586,9 @@ version = "2.0.2"
 
 [[deps.Optimisers]]
 deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "4b214125921ec010160ddb39931885e0a6585639"
+git-tree-sha1 = "6a01f65dd8583dee82eecc2a19b0ff21521aa749"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
-version = "0.2.17"
+version = "0.2.18"
 
 [[deps.Opus_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -1626,9 +1626,9 @@ version = "0.5.0"
 
 [[deps.PaddedViews]]
 deps = ["OffsetArrays"]
-git-tree-sha1 = "03a7a85b76381a3d04c7a1656039197e70eda03d"
+git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f"
 uuid = "5432bcbf-9aad-5242-b902-cca2824c8663"
-version = "0.5.11"
+version = "0.5.12"
 
 [[deps.Pango_jll]]
 deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg"]
@@ -1672,9 +1672,9 @@ version = "1.8.0"
 
 [[deps.PkgTemplates]]
 deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"]
-git-tree-sha1 = "e93643cc634d7551f68b63739d205d22216cec7c"
+git-tree-sha1 = "c0f12580abb41d7d11c1c7c65a1ff410f84c61e3"
 uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1"
-version = "0.7.32"
+version = "0.7.33"
 
 [[deps.PkgVersion]]
 deps = ["Pkg"]
@@ -1696,9 +1696,9 @@ version = "1.3.4"
 
 [[deps.Plots]]
 deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "Preferences", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SnoopPrecompile", "SparseArrays", "Statistics", "StatsBase", "UUIDs", "UnicodeFun", "Unzip"]
-git-tree-sha1 = "f49a45a239e13333b8b936120fe6d793fe58a972"
+git-tree-sha1 = "5434b0ee344eaf2854de251f326df8720f6a7b55"
 uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
-version = "1.38.8"
+version = "1.38.10"
 
 [[deps.PolygonOps]]
 git-tree-sha1 = "77b3d3605fc1cd0b42d95eba87dfcd2bf67d5ff6"
@@ -2023,9 +2023,9 @@ version = "0.1.1"
 
 [[deps.StaticArrays]]
 deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
-git-tree-sha1 = "b8d897fe7fa688e93aef573711cb207c08c9e11e"
+git-tree-sha1 = "63e84b7fdf5021026d0f17f76af7c57772313d99"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.5.19"
+version = "1.5.21"
 
 [[deps.StaticArraysCore]]
 git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
@@ -2062,9 +2062,9 @@ version = "1.3.0"
 
 [[deps.StatsModels]]
 deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Printf", "REPL", "ShiftedArrays", "SparseArrays", "StatsBase", "StatsFuns", "Tables"]
-git-tree-sha1 = "06a230063087c11910e9bbd17ccbf5af792a27a4"
+git-tree-sha1 = "51cdf1afd9d78552e7a08536930d7abc3b288a5c"
 uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
-version = "0.7.0"
+version = "0.7.1"
 
 [[deps.StatsPlots]]
 deps = ["AbstractFFTs", "Clustering", "DataStructures", "DataValues", "Distributions", "Interpolations", "KernelDensity", "LinearAlgebra", "MultivariateStats", "NaNMath", "Observables", "Plots", "RecipesBase", "RecipesPipeline", "Reexport", "StatsBase", "TableOperations", "Tables", "Widgets"]
@@ -2151,9 +2151,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [[deps.Tidier]]
 deps = ["Chain", "Cleaner", "DataFrames", "MacroTools", "Reexport", "ShiftedArrays", "Statistics"]
-git-tree-sha1 = "6c01fc23066480d998d3eabba0f999c46e88ed73"
+git-tree-sha1 = "107614fa8ae68ca12a193df719ec2438e86ab97b"
 uuid = "f0413319-3358-4bb0-8e7c-0c83523a93bd"
-version = "0.7.1"
+version = "0.7.4"
 
 [[deps.TiffImages]]
 deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"]
@@ -2181,9 +2181,9 @@ version = "0.2.20"
 
 [[deps.TranscodingStreams]]
 deps = ["Random", "Test"]
-git-tree-sha1 = "94f38103c984f89cf77c402f2a68dbd870f8165f"
+git-tree-sha1 = "0b829474fed270a4b0ab07117dce9b9a2fa7581a"
 uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
-version = "0.9.11"
+version = "0.9.12"
 
 [[deps.Transducers]]
 deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
@@ -2232,9 +2232,9 @@ version = "0.4.1"
 
 [[deps.UnicodePlots]]
 deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "Printf", "Requires", "SnoopPrecompile", "SparseArrays", "StaticArrays", "StatsBase"]
-git-tree-sha1 = "a5bcfc23e352f499a1a46f428d0d3d7fb9e4fc11"
+git-tree-sha1 = "2825e58f6ec3cab889dfa2c824f8d89b9f7ee731"
 uuid = "b8865327-cd53-5732-bb35-84acbb429228"
-version = "3.4.1"
+version = "3.5.1"
 
 [[deps.UnsafeAtomics]]
 git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
@@ -2438,15 +2438,15 @@ version = "1.2.12+3"
 
 [[deps.Zstd_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl"]
-git-tree-sha1 = "c6edfe154ad7b313c01aceca188c05c835c67360"
+git-tree-sha1 = "49ce682769cd5de6c72dcf1b94ed7790cd08974c"
 uuid = "3161d3a3-bdf6-5164-811a-617609db77b4"
-version = "1.5.4+0"
+version = "1.5.5+0"
 
 [[deps.Zygote]]
 deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SnoopPrecompile", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "4df8f470806a45a8630ac8f597304821dc8e8838"
+git-tree-sha1 = "987ae5554ca90e837594a0f30325eeb5e7303d1e"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.59"
+version = "0.6.60"
 
 [[deps.ZygoteRules]]
 deps = ["ChainRulesCore", "MacroTools"]
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 5e7f7aa615f8b9bf3f7c549c7c61f349961e3c7c..8a7f179f1dc24a1e55975659fc4c2a5b45d6ffa6 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -15,10 +15,10 @@ end
 
 ```{julia}
 # Data:
-n_obs = 10000
+n_obs = 1000
 counterfactual_data = load_mnist(n_obs)
+counterfactual_data.X = pre_process.(counterfactual_data.X)
 X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
-X = pre_process.(X)
 X = table(permutedims(X))
 labels = counterfactual_data.output_encoder.labels
 input_dim, n_obs = size(counterfactual_data.X)
@@ -32,18 +32,22 @@ First, let's create a couple of image classifier architectures:
 # Model parameters:
 epochs = 100
 batch_size = minimum([Int(round(n_obs/10)), 128])
-n_hidden = 32
+n_hidden = 50
 activation = Flux.swish
-# builder = MLJFlux.@builder Flux.Chain(
-#     Dense(n_in, n_hidden, activation),
-#     Dense(n_hidden, n_hidden, activation),
-#     Dense(n_hidden, n_hidden, activation),
-#     # BatchNorm(n_hidden, activation),
-#     # Dense(n_hidden, n_hidden),
-#     # BatchNorm(n_hidden, activation),
-#     Dense(n_hidden, n_out),
-# )
-builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation)
+builder = MLJFlux.@builder Flux.Chain(
+
+    Dense(n_in, n_hidden, activation),
+    # Dense(n_hidden, n_hidden, activation),
+    # Dense(n_hidden, n_hidden, activation),
+    
+    # Dense(n_in, n_hidden),
+    # BatchNorm(n_hidden, activation),
+    # Dense(n_hidden, n_hidden),
+    # BatchNorm(n_hidden, activation),
+
+    Dense(n_hidden, n_out),
+)
+# builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation)
 # builder = MLJFlux.MLP(
 #     hidden=(
 #         n_hidden,
@@ -52,7 +56,7 @@ builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation)
 #     ), 
 #     σ=activation
 # )
-α = [1.0,1.0,1e-1]
+α = [1.0,1.0,1e-2]
 
 # Simple MLP:
 mlp = NeuralNetworkClassifier(
@@ -85,7 +89,7 @@ jem = JointEnergyClassifier(
 )
 
 # Deep Ensemble:
-mlp_ens = EnsembleModel(model=mlp, n=5)
+mlp_ens = EnsembleModel(model=mlp, n=50)
 ```
 
 ```{julia}
@@ -99,7 +103,7 @@ M = ECCCo.ConformalModel(mach.model, mach.fitresult)
 ```{julia}
 if mach.model.model isa JointEnergyModels.JointEnergyClassifier
     jem = mach.model.model.jem
-    n_iter = 500
+    n_iter = 200
     _w = 1500
     plts = []
     neach = 10
@@ -121,6 +125,7 @@ end
 
 ```{julia}
 test_data = load_mnist_test()
+test_data.X = pre_process.(test_data.X)
 f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
 println("F1 score (test): $(round(f1,digits=3))")
 ```
@@ -153,9 +158,9 @@ ce_jsma = generate_counterfactual(
 
 # ECCCo:
 λ=[0.0,1.0]
-temp=0.01
+temp=0.5
 
-# Generate counterfactual using ECCCo generator:
+# Generate counterfactual using CCE generator:
 generator = CCEGenerator(
     λ=λ, 
     temp=temp, 
@@ -168,7 +173,7 @@ ce_conformal = generate_counterfactual(
     converge_when=:generator_conditions,
 )
 
-# Generate counterfactual using ECCCo generator:
+# Generate counterfactual using CCE generator:
 generator = CCEGenerator(
     λ=λ, 
     temp=temp, 
@@ -191,7 +196,7 @@ p1 = Plots.plot(
 plts = [p1]
 
 ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma]
-_names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"]
+_names = ["Wachter", "CCE", "JSMA", "CCE-JSMA"]
 for x in zip(ces, _names)
     ce, _name = (x[1],x[2])
     x = CounterfactualExplanations.counterfactual(ce)
@@ -221,6 +226,8 @@ factual = predict_label(M, counterfactual_data, x)[1]
 γ = 0.5
 T = 100
 
+η=0.1
+
 # Generate counterfactual using generic generator:
 generator = GenericGenerator(opt=Flux.Optimise.Adam(),)
 ce_wachter = generate_counterfactual(
@@ -229,7 +236,7 @@ ce_wachter = generate_counterfactual(
     initialization=:identity,
 )
 
-generator = GreedyGenerator(η=1.0)
+generator = GreedyGenerator(η=η)
 ce_jsma = generate_counterfactual(
     x, target, counterfactual_data, M, generator; 
     decision_threshold=γ, max_iter=T,
@@ -237,8 +244,8 @@ ce_jsma = generate_counterfactual(
 )
 
 # ECCCo:
-λ=[0.0,1.0,1.0]
-temp=0.01
+λ=[0.0,0.0,10.0]
+temp=0.5
 
 # Generate counterfactual using ECCCo generator:
 generator = ECCCoGenerator(
@@ -257,7 +264,7 @@ ce_conformal = generate_counterfactual(
 generator = ECCCoGenerator(
     λ=λ, 
     temp=temp, 
-    opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
+    opt=CounterfactualExplanations.Generators.JSMADescent(η=η),
 )
 ce_conformal_jsma = generate_counterfactual(
     x, target, counterfactual_data, M, generator; 
@@ -292,7 +299,30 @@ for x in zip(ces, _names)
 end
 plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
 display(plt)
-savefig(plt, joinpath(www_path, "cce_mnist.png"))
+savefig(plt, joinpath(www_path, "eccco_mnist.png"))
+```
+
+```{julia}
+if M.model.model isa JointEnergyModels.JointEnergyClassifier
+    jem = M.model.model.jem
+    n_iter = 200
+    _w = 1500
+    plts = []
+    neach = 10
+    for i in 1:10
+        x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i)
+        plts_i = []
+        for j in 1:size(x, 2)
+            xj = x[:,j]
+            xj = reshape(xj, (n_digits, n_digits))
+            plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
+        end
+        plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
+        plts = [plts..., plt]
+    end
+    plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1))
+    display(plt)
+end
 ```
 
 ## Benchmark
diff --git a/paper/paper.pdf b/paper/paper.pdf
index 2ba71c2968b35cca31816ccb45be28b071acf7fa..222cec3cdb4a1bfb1966a554996d478b458fc381 100644
Binary files a/paper/paper.pdf and b/paper/paper.pdf differ
diff --git a/paper/paper.tex b/paper/paper.tex
index bd18cf2a571a04e33ac4f7085bbdb62132d7190c..6ee1a43bf68b58fb0a7c4c743d1a728573ad951c 100644
--- a/paper/paper.tex
+++ b/paper/paper.tex
@@ -280,6 +280,13 @@ In order to generate prediction sets $C_{\theta}(f(\mathbf{Z}^\prime);\alpha)$ f
 
 \section{Experiments}
 
+\begin{itemize}
+  \item BatchNorm does not seem compatible with JEM
+  \item Coverage and temperature impacts CCE in somewhat unpredictable ways
+  \item It seems that models that are not explicitly trained for generative task, still learn it implictly
+  \item Batch size seems to impact quality of generated samples
+\end{itemize}
+
 \section{Discussion}
 
 Consistent with the findings in \citet{schut2021generating}, we have demonstrated that predictive uncertainty estimates can be leveraged to generate plausible counterfactuals. Interestingly, \citet{schut2021generating} point out that this finding --- as intuitive as it is --- may be linked to a positive connection between the generative task and predictive uncertainty quantification. In particular, \citet{grathwohl2020your} demonstrate that their proposed method for integrating the generative objective in training yields models that have improved predictive uncertainty quantification. Since neither \citet{schut2021generating} nor we have employed any surrogate generative models, our findings seem to indicate that the positive connection found in \citet{grathwohl2020your} is bidirectional.
diff --git a/src/ECCCo.jl b/src/ECCCo.jl
index 82c901d72e83b0b8ef6abf2ca43d5801c1403489..33135f6352d0a3f0a53b20dd5ac1f64e64ae6c70 100644
--- a/src/ECCCo.jl
+++ b/src/ECCCo.jl
@@ -9,6 +9,7 @@ include("losses.jl")
 include("generator.jl")
 include("sampling.jl")
 
-export ECCCoGenerator, EnergySampler, set_size_penalty, distance_from_energy
+export CCEGenerator, ECCCoGenerator, EnergySampler
+export set_size_penalty, distance_from_energy
 
 end
\ No newline at end of file
diff --git a/src/generator.jl b/src/generator.jl
index 2dc3e460c8bfed0846ecd56f90f6792fd99f3280..72952aa31de89bf7d1a8d297a8337abe97cf9a16 100644
--- a/src/generator.jl
+++ b/src/generator.jl
@@ -1,7 +1,12 @@
 using CounterfactualExplanations.Objectives
 
-"Constructor for `ECCCoGenerator`."
-function ECCCoGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], κ::Real=1.0, temp::Real=0.05, kwargs...)
+"Constructor for `CCEGenerator`."
+function CCEGenerator(; 
+    λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], 
+    κ::Real=1.0, 
+    temp::Real=0.5, 
+    kwargs...
+)
     function _set_size_penalty(ce::AbstractCounterfactualExplanation)
         return ECCCo.set_size_penalty(ce; κ=κ, temp=temp)
     end
@@ -11,7 +16,7 @@ function ECCCoGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1,
 end
 
 "Constructor for `ECECCCoGenerator`: Energy Constrained Conformal Counterfactual Explanation Generator."
-function ECECCCoGenerator(; 
+function ECCCoGenerator(; 
     λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0, 1.0], 
     κ::Real=1.0, 
     temp::Real=0.5, 
diff --git a/src/penalties.jl b/src/penalties.jl
index 81b953fe5884a9fec638a2db01df73edcdd3176a..95c96b4109cba3bf5f40d4c8f7ae03e44bfefceb 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -36,20 +36,19 @@ end
 
 function distance_from_energy(
     ce::AbstractCounterfactualExplanation;
-    n::Int=10000, from_buffer=true, agg=mean, kwargs...
+    n::Int=1, niter=200, from_buffer=true, agg=mean, kwargs...
 )
     conditional_samples = []
     ignore_derivatives() do
         _dict = ce.params
         if !(:energy_sampler ∈ collect(keys(_dict)))
-            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; kwargs...)
+            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=1000, kwargs...)
         end
         sampler = _dict[:energy_sampler]
         push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer))
     end
     x′ = CounterfactualExplanations.counterfactual(ce)
-    loss = map(eachslice(x′, dims=3)) do x
-        x = Matrix(x)
+    loss = map(eachslice(x′, dims=ndims(x′))) do x
         Δ = map(eachcol(conditional_samples[1])) do xsample
             norm(x - xsample)
         end
@@ -63,7 +62,7 @@ end
 
 function distance_from_targets(
     ce::AbstractCounterfactualExplanation;
-    n::Int=10000, agg=mean
+    n::Int=1000, agg=mean
 )
     target_idx = ce.data.output_encoder.labels .== ce.target
     target_samples = ce.data.X[:,target_idx] |>
diff --git a/src/sampling.jl b/src/sampling.jl
index 4f57fc408631300cee03fea0bc28a071635a3d32..50a170283a65f6e0ad64faae58084913898f0a95 100644
--- a/src/sampling.jl
+++ b/src/sampling.jl
@@ -42,42 +42,29 @@ function EnergySampler(
 )
 
     @assert y ∈ data.y_levels || y ∈ 1:length(data.y_levels)
-    K = length(data.y_levels)
-    𝒟x = Normal()
-    𝒟y = Categorical(ones(K) ./ K)
-    sampler = ConditionalSampler(𝒟x, 𝒟y)
+
+    if model.model.model isa JointEnergyClassifier
+        sampler = model.model.model.jem.sampler
+    else
+        K = length(data.y_levels)
+        input_size = size(selectdim(data.X, ndims(data.X), 1))
+        𝒟x = Uniform(extrema(data.X)...)
+        𝒟y = Categorical(ones(K) ./ K)
+        sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size)
+    end
     yidx = get_target_index(data.y_levels, y)
 
     # Initiate:
     energy_sampler = EnergySampler(model, data, sampler, opt, nothing, nothing)
 
     # Generate samples:
-    generate_samples!(energy_sampler, nsamples, yidx; niter=niter)
+    chain = model.model.model.jem.chain
+    rule = model.model.model.jem.sampling_rule
+    energy_sampler.sampler(chain, rule; niter=niter, n_samples=nsamples, y=yidx)
 
     return energy_sampler
 end
 
-"""
-    generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
-
-Generates `n` samples from `EnergySampler` for conditioning value `y`.
-"""
-function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
-    X = e.sampler(e.model, e.opt, (size(e.data.X, 1), n); niter=niter, y=y)
-    X = X[:,map(x -> !any(isnan.(x)), eachcol(X))]
-    return X
-end
-
-"""
-    generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100)
-
-Generates `n` samples from `EnergySampler` for conditioning value `y`. Assigns samples and conditioning value to `EnergySampler`.
-"""
-function generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100)
-    e.buffer = generate_samples(e,n,y;niter=niter)
-    e.yidx = y
-end
-
 """
     EnergySampler(
         ce::CounterfactualExplanation;
@@ -105,12 +92,13 @@ end
 Overloads the `rand` method to randomly draw `n` samples from `EnergySampler`.
 """
 function Base.rand(sampler::EnergySampler, n::Int=100; from_buffer=true, niter::Int=100)
-    ntotal = size(sampler.buffer, 2)
+    ntotal = size(sampler.sampler.buffer)[end]
     idx = rand(1:ntotal, n)
     if from_buffer
-        X = sampler.buffer[:, idx]
+        X = sampler.sampler.buffer[:, idx]
     else
-        X = generate_samples(sampler, n, sampler.yidx; niter=niter)
+        chain = sampler.model.fitresult[1]
+        X = sampler.sampler(chain, sampler.opt; niter=niter, n_samples=n, y=sampler.yidx)
     end
     return X
 end
diff --git a/www/cce_mnist.png b/www/cce_mnist.png
index 55ff8d5650856abe43798dcd34a10da9de1c34f0..c96db716f561b5792813f3ace77acb7a51180819 100644
Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ
diff --git a/www/eccco_mnist.png b/www/eccco_mnist.png
new file mode 100644
index 0000000000000000000000000000000000000000..81cecbe4d0f947bf55ef08cd7f2db35f318337c2
Binary files /dev/null and b/www/eccco_mnist.png differ