From 5db91e75281cbaf2f113393868cb8c59170f4b2f Mon Sep 17 00:00:00 2001
From: pat-alt <altmeyerpat@gmail.com>
Date: Thu, 14 Sep 2023 15:24:43 +0200
Subject: [PATCH] added resnet code

---
 Project.toml                            |  3 +-
 experiments/Manifest.toml               | 73 ++++++++++++++-----------
 experiments/Project.toml                |  1 +
 experiments/models/additional_models.jl |  2 +-
 4 files changed, 44 insertions(+), 35 deletions(-)

diff --git a/Project.toml b/Project.toml
index 62e55cb1..2f8d9671 100644
--- a/Project.toml
+++ b/Project.toml
@@ -23,6 +23,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Term = "22787eb5-b846-44ae-b979-8e399b8463ab"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [compat]
 CategoricalArrays = "0.10.8"
@@ -35,7 +36,7 @@ Flux = "0.13 - 0.14"
 JointEnergyModels = "0.1.1"
 MLJBase = "0.21.13"
 MLJEnsembles = "0.3.3"
-MLJFlux = "0.2.10"
+MLJFlux = "0.2.10, 0.3"
 MLJModelInterface = "1.8.0"
 MLUtils = "0.4.3"
 Parameters = "0.12.3"
diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml
index 259ecde9..e80158c7 100644
--- a/experiments/Manifest.toml
+++ b/experiments/Manifest.toml
@@ -2,7 +2,7 @@
 
 julia_version = "1.9.3"
 manifest_format = "2.0"
-project_hash = "f97f8534bd2363738071587ac7055fab77fc36f4"
+project_hash = "3525ff74c1eae065ab039bda0ac1230901950a3b"
 
 [[deps.ARFFFiles]]
 deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"]
@@ -279,9 +279,9 @@ weakdeps = ["JSON", "RecipesBase", "SentinelArrays", "StructTypes"]
 
 [[deps.CategoricalDistributions]]
 deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes"]
-git-tree-sha1 = "da68989f027dcefa74d44a452c9e36af9730a70d"
+git-tree-sha1 = "ed760a4fde49997ff9360a780abe6e20175162aa"
 uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
-version = "0.1.10"
+version = "0.1.11"
 
     [deps.CategoricalDistributions.extensions]
     UnivariateFiniteDisplayExt = "UnicodePlots"
@@ -416,9 +416,9 @@ version = "2.2.1"
 
 [[deps.ConformalPrediction]]
 deps = ["CategoricalArrays", "ChainRules", "ComputationalResources", "Flux", "LazyArtifacts", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "NaturalSort", "ProgressMeter", "Random", "Serialization", "StatsBase", "Tables"]
-git-tree-sha1 = "dc5de97a5304935398ab5aee3d86a91db2805441"
+git-tree-sha1 = "5ae9a307c350e2afeb411b2f4cc89ac2efbe9823"
 uuid = "98bfc277-1877-43dc-819b-a3e38c30242f"
-version = "0.1.11"
+version = "0.1.12"
 
 [[deps.ConstructionBase]]
 deps = ["LinearAlgebra"]
@@ -584,7 +584,7 @@ uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
 version = "0.6.8"
 
 [[deps.ECCCo]]
-deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "JointEnergyModels", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "Random", "Statistics", "StatsBase", "Term"]
+deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "JointEnergyModels", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "Random", "Statistics", "StatsBase", "Term", "cuDNN"]
 path = ".."
 uuid = "0232c203-4013-4b0d-ad96-43e3e11ac3bf"
 version = "0.1.0"
@@ -730,18 +730,22 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
 version = "0.8.4"
 
 [[deps.Flux]]
-deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote", "cuDNN"]
-git-tree-sha1 = "3e2c3704c2173ab4b1935362384ca878b53d4c34"
+deps = ["Adapt", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"]
+git-tree-sha1 = "7ffd73e8ca363f80367123aab0c5d0edabab4e60"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.13.17"
+version = "0.14.5"
 
     [deps.Flux.extensions]
-    AMDGPUExt = "AMDGPU"
+    FluxAMDGPUExt = "AMDGPU"
+    FluxCUDAExt = "CUDA"
+    FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
     FluxMetalExt = "Metal"
 
     [deps.Flux.weakdeps]
     AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
+    CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
     Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
+    cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [[deps.Fontconfig_jll]]
 deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Pkg", "Zlib_jll"]
@@ -1092,9 +1096,9 @@ version = "1.13.2"
 
 [[deps.JointEnergyModels]]
 deps = ["CategoricalArrays", "ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "ProgressMeter", "Random", "StatsBase", "Tables", "Zygote"]
-git-tree-sha1 = "ae360013ba7ccb732c59bdbbaa53aaf96f09e2d7"
+git-tree-sha1 = "1b4a8ae085fa69edf41b66bd6a18bc7cf37465c2"
 uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
-version = "0.1.2"
+version = "0.1.3"
 
 [[deps.JpegTurbo]]
 deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"]
@@ -1270,9 +1274,9 @@ version = "0.1.12"
 
 [[deps.Loess]]
 deps = ["Distances", "LinearAlgebra", "Statistics", "StatsAPI"]
-git-tree-sha1 = "9403bfea9bc9acc9c7d803a1b39d0a668ed40f03"
+git-tree-sha1 = "a113a8be4c6d0c64e217b472fb6e61c760eb4022"
 uuid = "4345ca2d-374a-55d4-8d30-97f9976e7612"
-version = "0.6.2"
+version = "0.6.3"
 
 [[deps.LogExpFunctions]]
 deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
@@ -1365,9 +1369,9 @@ version = "0.1.1"
 
 [[deps.MLJFlux]]
 deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"]
-git-tree-sha1 = "b27c3b96cc2a602a1e91eba36b8ca3d796f30ae0"
+git-tree-sha1 = "40f9e99a6770bc795f70f1908316e1491488a7b7"
 uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
-version = "0.2.10"
+version = "0.3.1"
 
 [[deps.MLJIteration]]
 deps = ["IterationControl", "MLJBase", "Random", "Serialization"]
@@ -1492,10 +1496,10 @@ uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
 version = "2.28.2+0"
 
 [[deps.Metalhead]]
-deps = ["Artifacts", "BSON", "Flux", "Functors", "LazyArtifacts", "MLUtils", "NNlib", "Random", "Statistics"]
-git-tree-sha1 = "0e95f91cc5f23610f8f270d7397f307b21e19d2b"
+deps = ["Artifacts", "BSON", "CUDA", "ChainRulesCore", "Flux", "Functors", "JLD2", "LazyArtifacts", "MLUtils", "NNlib", "PartialFunctions", "Random", "Statistics"]
+git-tree-sha1 = "c093734078e92a4edcf54e850af68ef8cc2c9e03"
 uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
-version = "0.7.4"
+version = "0.8.2"
 
 [[deps.MicroCollections]]
 deps = ["BangBang", "InitialValues", "Setfield"]
@@ -1557,21 +1561,19 @@ version = "7.8.3"
 
 [[deps.NNlib]]
 deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
-git-tree-sha1 = "72240e3f5ca031937bd536182cb2c031da5f46dd"
+git-tree-sha1 = "3b29fafcdfa66d6673306cf116a2dc243933e2c5"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.8.21"
+version = "0.9.5"
 
     [deps.NNlib.extensions]
     NNlibAMDGPUExt = "AMDGPU"
+    NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
+    NNlibCUDAExt = "CUDA"
 
     [deps.NNlib.weakdeps]
     AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
-
-[[deps.NNlibCUDA]]
-deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"]
-git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf"
-uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
-version = "0.2.7"
+    CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+    cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [[deps.NPZ]]
 deps = ["FileIO", "ZipFile"]
@@ -1707,9 +1709,9 @@ version = "1.7.7"
 
 [[deps.Optimisers]]
 deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b"
+git-tree-sha1 = "34205b1204cc83c43cd9cfe53ffbd3b310f6e8c5"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
-version = "0.2.20"
+version = "0.3.1"
 
 [[deps.Opus_jll]]
 deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -1781,6 +1783,11 @@ git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851"
 uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
 version = "2.7.2"
 
+[[deps.PartialFunctions]]
+git-tree-sha1 = "b3901ea034cfd8aae57a2fa0dde0b0ea18bad1cb"
+uuid = "570af359-4316-4cb7-8c74-252c00c2016b"
+version = "1.1.1"
+
 [[deps.PeriodicTable]]
 deps = ["Base64", "Test", "Unitful"]
 git-tree-sha1 = "9a9731f346797126271405971dfdf4709947718b"
@@ -1829,9 +1836,9 @@ version = "0.1.2"
 
 [[deps.Polynomials]]
 deps = ["LinearAlgebra", "RecipesBase", "Setfield", "SparseArrays"]
-git-tree-sha1 = "b5d848e70cdf62f6896d29494c2a69ce4610ea8d"
+git-tree-sha1 = "ea78a2764f31715093de7ab495e12c0187f231d1"
 uuid = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
-version = "4.0.3"
+version = "4.0.4"
 
     [deps.Polynomials.extensions]
     PolynomialsChainRulesCoreExt = "ChainRulesCore"
@@ -2382,9 +2389,9 @@ uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
 version = "0.3.5"
 
 [[deps.TupleTools]]
-git-tree-sha1 = "3c712976c47707ff893cf6ba4354aa14db1d8938"
+git-tree-sha1 = "f4dfd6fc59551c4e70fbcd75ee36ef602b0a8f29"
 uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
-version = "1.3.0"
+version = "1.4.2"
 
 [[deps.URIs]]
 git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0"
diff --git a/experiments/Project.toml b/experiments/Project.toml
index da0565ef..789c7c9b 100644
--- a/experiments/Project.toml
+++ b/experiments/Project.toml
@@ -21,4 +21,5 @@ Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
 Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
 TidierData = "fe2206b3-d496-4ee9-a338-6a095c4ece80"
 TidierPlots = "337ecbd1-5042-4e2a-ae6f-ca776f97570a"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 ghr_jll = "07c12ed4-43bc-5495-8a2a-d5838ef8d533"
diff --git a/experiments/models/additional_models.jl b/experiments/models/additional_models.jl
index 9691f60f..d6e8f814 100644
--- a/experiments/models/additional_models.jl
+++ b/experiments/models/additional_models.jl
@@ -62,7 +62,7 @@ Overloads the MLJFlux build function for a LeNet-like convolutional neural netwo
 """
 function MLJFlux.build(b::ResNetBuilder, rng, n_in, n_out)
     _n_in = Int(sqrt(n_in))
-    front = Metalhead.ResNet(18; pretrain=true, inchannels=1)
+    front = Metalhead.ResNet(18; inchannels=1)
     d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first
     back = Flux.Chain(
         Dense(d, 120, relu),
-- 
GitLab