diff --git a/Manifest.toml b/Manifest.toml
index 895027abfcdbe69cc3ec768bac44e3f7cce525a8..ebe111ccf005ad23530b9ccfd8b06694c9136bed 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -2,7 +2,7 @@
 
 julia_version = "1.8.5"
 manifest_format = "2.0"
-project_hash = "a3f5e7cf561325560c244a222fe34c51dfd04226"
+project_hash = "932009a0a4d8e913f882578b8740427ca967bce1"
 
 [[deps.AbstractFFTs]]
 deps = ["ChainRulesCore", "LinearAlgebra"]
@@ -49,12 +49,6 @@ git-tree-sha1 = "5ba6c757e8feccf03a1554dfaf3e26b3cfc7fd5e"
 uuid = "68821587-b530-5797-8361-c406ea357684"
 version = "3.5.1+1"
 
-[[deps.ArrayInterface]]
-deps = ["ArrayInterfaceCore", "Compat", "IfElse", "LinearAlgebra", "SnoopPrecompile", "Static"]
-git-tree-sha1 = "dedc16cbdd1d32bead4617d27572f582216ccf23"
-uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
-version = "6.0.25"
-
 [[deps.ArrayInterfaceCore]]
 deps = ["LinearAlgebra", "SnoopPrecompile", "SparseArrays", "SuiteSparse"]
 git-tree-sha1 = "e5f08b5689b1aad068e01751889f2f615c7db36d"
@@ -66,14 +60,9 @@ uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
 
 [[deps.BFloat16s]]
 deps = ["LinearAlgebra", "Printf", "Random", "Test"]
-git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072"
+git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66"
 uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
-version = "0.2.0"
-
-[[deps.BSON]]
-git-tree-sha1 = "86e9781ac28f4e80e9b98f7f96eae21891332ac2"
-uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
-version = "0.3.6"
+version = "0.4.2"
 
 [[deps.BangBang]]
 deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"]
@@ -123,10 +112,34 @@ uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 version = "0.10.9"
 
 [[deps.CUDA]]
-deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
-git-tree-sha1 = "6717cb9a3425ebb7b31ca4f832823615d175f64a"
+deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions"]
+git-tree-sha1 = "edff14c60784c8f7191a62a23b15a421185bc8a8"
 uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
-version = "3.13.1"
+version = "4.0.1"
+
+[[deps.CUDA_Driver_jll]]
+deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
+git-tree-sha1 = "75d7896d1ec079ef10d3aee8f3668c11354c03a1"
+uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc"
+version = "0.2.0+0"
+
+[[deps.CUDA_Runtime_Discovery]]
+deps = ["Libdl"]
+git-tree-sha1 = "58dd8ec29f54f08c04b052d2c2fa6760b4f4b3a4"
+uuid = "1af6417a-86b4-443c-805f-a4643ffb695f"
+version = "0.1.1"
+
+[[deps.CUDA_Runtime_jll]]
+deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
+git-tree-sha1 = "d3e6ccd30f84936c1a3a53d622d85d7d3f9b9486"
+uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
+version = "0.2.3+2"
+
+[[deps.CUDNN_jll]]
+deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
+git-tree-sha1 = "57011df4fce448828165e566af9befa2ea94350a"
+uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
+version = "8.6.0+3"
 
 [[deps.Cairo_jll]]
 deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
@@ -439,10 +452,10 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
 version = "0.8.4"
 
 [[deps.Flux]]
-deps = ["Adapt", "ArrayInterface", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Test", "Zygote"]
-git-tree-sha1 = "96dc065bf4a998e8adeebc0ff1302902b6e59362"
+deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote"]
+git-tree-sha1 = "c258e51850ac1fdc465f62380a61995d4a66d603"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.13.4"
+version = "0.13.12"
 
 [[deps.FoldsThreads]]
 deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"]
@@ -486,9 +499,10 @@ uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
 version = "1.1.3"
 
 [[deps.Functors]]
-git-tree-sha1 = "223fffa49ca0ff9ce4f875be001ffe173b2b7de4"
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "61fa9cf802d35fe1b5b8ea9fbaac4b8f020d19b1"
 uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
-version = "0.2.8"
+version = "0.4.2"
 
 [[deps.Future]]
 deps = ["Random"]
@@ -602,14 +616,9 @@ version = "0.3.11"
 
 [[deps.IRTools]]
 deps = ["InteractiveUtils", "MacroTools", "Test"]
-git-tree-sha1 = "2e99184fca5eb6f075944b04c22edec29beb4778"
+git-tree-sha1 = "2af2fe19f0d5799311a6491267a14817ad9fbd20"
 uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
-version = "0.4.7"
-
-[[deps.IfElse]]
-git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1"
-uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
-version = "0.1.1"
+version = "0.4.8"
 
 [[deps.ImageBase]]
 deps = ["ImageCore", "Reexport"]
@@ -910,10 +919,10 @@ uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 version = "0.21.5"
 
 [[deps.MLJFlux]]
-deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"]
-git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0"
+deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "ProgressMeter", "Random", "Statistics", "Tables"]
+git-tree-sha1 = "a47257705ebca405a25320b111345a978925bcd5"
 uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
-version = "0.2.9"
+version = "0.2.7"
 
 [[deps.MLJModelInterface]]
 deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
@@ -923,9 +932,9 @@ version = "1.8.0"
 
 [[deps.MLJModels]]
 deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
-git-tree-sha1 = "08203fc87a7f992cee24e7a1b2353e594c73c41c"
+git-tree-sha1 = "638c84dea26eb91a538afb52cef11d77556e6807"
 uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
-version = "0.16.2"
+version = "0.16.3"
 
 [[deps.MLStyle]]
 git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
@@ -933,10 +942,10 @@ uuid = "d8e11817-5142-5d16-987a-aa16d5891078"
 version = "0.4.17"
 
 [[deps.MLUtils]]
-deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase", "Transducers"]
-git-tree-sha1 = "824e9dfc7509cab1ec73ba77b55a916bb2905e26"
+deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "FoldsThreads", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"]
+git-tree-sha1 = "f69cdbb5b7c630c02481d81d50eac43697084fe0"
 uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
-version = "0.2.11"
+version = "0.4.1"
 
 [[deps.MacroTools]]
 deps = ["Markdown", "Random"]
@@ -975,12 +984,6 @@ git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102"
 uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e"
 version = "0.3.2"
 
-[[deps.Metalhead]]
-deps = ["Artifacts", "BSON", "Flux", "Functors", "LazyArtifacts", "MLUtils", "NNlib", "Random", "Statistics"]
-git-tree-sha1 = "a8513152030f7210ccc0b871e03d60c9b13ed0b1"
-uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
-version = "0.7.3"
-
 [[deps.MicroCollections]]
 deps = ["BangBang", "InitialValues", "Setfield"]
 git-tree-sha1 = "4d5917a26ca33c66c8e5ca3247bd163624d35493"
@@ -1037,10 +1040,10 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 version = "0.8.18"
 
 [[deps.NNlibCUDA]]
-deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
-git-tree-sha1 = "b05a082b08a3af0e5c576883bc6dfb6513e7e478"
+deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"]
+git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf"
 uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
-version = "0.2.6"
+version = "0.2.7"
 
 [[deps.NPZ]]
 deps = ["FileIO", "ZipFile"]
@@ -1099,6 +1102,12 @@ git-tree-sha1 = "887579a3eb005446d514ab7aeac5d1d027658b8f"
 uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051"
 version = "1.3.5+1"
 
+[[deps.OneHotArrays]]
+deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"]
+git-tree-sha1 = "f511fca956ed9e70b80cd3417bb8c2dde4b68644"
+uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
+version = "0.2.3"
+
 [[deps.OpenBLAS_jll]]
 deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
 uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
@@ -1135,9 +1144,9 @@ version = "2.0.2"
 
 [[deps.Optimisers]]
 deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "1ef34738708e3f31994b52693286dabcb3d29f6b"
+git-tree-sha1 = "e657acef119cc0de2a8c0762666d3b64727b053b"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
-version = "0.2.9"
+version = "0.2.14"
 
 [[deps.Opus_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -1453,12 +1462,6 @@ git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c"
 uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
 version = "0.1.1"
 
-[[deps.Static]]
-deps = ["IfElse"]
-git-tree-sha1 = "c35b107b61e7f34fa3f124026f2a9be97dea9e1c"
-uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
-version = "0.8.3"
-
 [[deps.StaticArrays]]
 deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
 git-tree-sha1 = "cee507162ecbb677450f20058ca83bd559b6b752"
@@ -1822,6 +1825,12 @@ git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0"
 uuid = "700de1a5-db45-46bc-99cf-38207098b444"
 version = "0.2.2"
 
+[[deps.cuDNN]]
+deps = ["CEnum", "CUDA", "CUDNN_jll"]
+git-tree-sha1 = "c0ffcb38d1e8c0bbcd3dab2559cf9c369130b2f2"
+uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
+version = "1.0.1"
+
 [[deps.fzf_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
 git-tree-sha1 = "868e669ccb12ba16eaf50cb2957ee2ff61261c56"
diff --git a/Project.toml b/Project.toml
index 5a1cf95cc1483a07831447524b54b82249b33491..6cbf43c89bdb9ec787b9fc4ff8f92313aee55be9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,6 +10,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
+MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"