From 5db91e75281cbaf2f113393868cb8c59170f4b2f Mon Sep 17 00:00:00 2001 From: pat-alt <altmeyerpat@gmail.com> Date: Thu, 14 Sep 2023 15:24:43 +0200 Subject: [PATCH] added resnet code --- Project.toml | 3 +- experiments/Manifest.toml | 73 ++++++++++++++----------- experiments/Project.toml | 1 + experiments/models/additional_models.jl | 2 +- 4 files changed, 44 insertions(+), 35 deletions(-) diff --git a/Project.toml b/Project.toml index 62e55cb1..2f8d9671 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Term = "22787eb5-b846-44ae-b979-8e399b8463ab" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] CategoricalArrays = "0.10.8" @@ -35,7 +36,7 @@ Flux = "0.13 - 0.14" JointEnergyModels = "0.1.1" MLJBase = "0.21.13" MLJEnsembles = "0.3.3" -MLJFlux = "0.2.10" +MLJFlux = "0.2.10, 0.3" MLJModelInterface = "1.8.0" MLUtils = "0.4.3" Parameters = "0.12.3" diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml index 259ecde9..e80158c7 100644 --- a/experiments/Manifest.toml +++ b/experiments/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.9.3" manifest_format = "2.0" -project_hash = "f97f8534bd2363738071587ac7055fab77fc36f4" +project_hash = "3525ff74c1eae065ab039bda0ac1230901950a3b" [[deps.ARFFFiles]] deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] @@ -279,9 +279,9 @@ weakdeps = ["JSON", "RecipesBase", "SentinelArrays", "StructTypes"] [[deps.CategoricalDistributions]] deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes"] -git-tree-sha1 = "da68989f027dcefa74d44a452c9e36af9730a70d" +git-tree-sha1 = "ed760a4fde49997ff9360a780abe6e20175162aa" uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e" -version = "0.1.10" +version = "0.1.11" [deps.CategoricalDistributions.extensions] UnivariateFiniteDisplayExt = "UnicodePlots" @@ -416,9 +416,9 @@ version = "2.2.1" [[deps.ConformalPrediction]] deps = ["CategoricalArrays", "ChainRules", "ComputationalResources", "Flux", "LazyArtifacts", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "NaturalSort", "ProgressMeter", "Random", "Serialization", "StatsBase", "Tables"] -git-tree-sha1 = "dc5de97a5304935398ab5aee3d86a91db2805441" +git-tree-sha1 = "5ae9a307c350e2afeb411b2f4cc89ac2efbe9823" uuid = "98bfc277-1877-43dc-819b-a3e38c30242f" -version = "0.1.11" +version = "0.1.12" [[deps.ConstructionBase]] deps = ["LinearAlgebra"] @@ -584,7 +584,7 @@ uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" version = "0.6.8" [[deps.ECCCo]] -deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "JointEnergyModels", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "Random", "Statistics", "StatsBase", "Term"] +deps = ["CategoricalArrays", "ChainRules", "ConformalPrediction", "CounterfactualExplanations", "Distances", "Distributions", "Flux", "JointEnergyModels", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "Random", "Statistics", "StatsBase", "Term", "cuDNN"] path = ".." uuid = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" version = "0.1.0" @@ -730,18 +730,22 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.4" [[deps.Flux]] -deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote", "cuDNN"] -git-tree-sha1 = "3e2c3704c2173ab4b1935362384ca878b53d4c34" +deps = ["Adapt", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "7ffd73e8ca363f80367123aab0c5d0edabab4e60" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.13.17" +version = "0.14.5" [deps.Flux.extensions] - AMDGPUExt = "AMDGPU" + FluxAMDGPUExt = "AMDGPU" + FluxCUDAExt = "CUDA" + FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] FluxMetalExt = "Metal" [deps.Flux.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.Fontconfig_jll]] deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Pkg", "Zlib_jll"] @@ -1092,9 +1096,9 @@ version = "1.13.2" [[deps.JointEnergyModels]] deps = ["CategoricalArrays", "ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "ProgressMeter", "Random", "StatsBase", "Tables", "Zygote"] -git-tree-sha1 = "ae360013ba7ccb732c59bdbbaa53aaf96f09e2d7" +git-tree-sha1 = "1b4a8ae085fa69edf41b66bd6a18bc7cf37465c2" uuid = "48c56d24-211d-4463-bbc0-7a701b291131" -version = "0.1.2" +version = "0.1.3" [[deps.JpegTurbo]] deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] @@ -1270,9 +1274,9 @@ version = "0.1.12" [[deps.Loess]] deps = ["Distances", "LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "9403bfea9bc9acc9c7d803a1b39d0a668ed40f03" +git-tree-sha1 = "a113a8be4c6d0c64e217b472fb6e61c760eb4022" uuid = "4345ca2d-374a-55d4-8d30-97f9976e7612" -version = "0.6.2" +version = "0.6.3" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] @@ -1365,9 +1369,9 @@ version = "0.1.1" [[deps.MLJFlux]] deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] -git-tree-sha1 = "b27c3b96cc2a602a1e91eba36b8ca3d796f30ae0" +git-tree-sha1 = "40f9e99a6770bc795f70f1908316e1491488a7b7" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" -version = "0.2.10" +version = "0.3.1" [[deps.MLJIteration]] deps = ["IterationControl", "MLJBase", "Random", "Serialization"] @@ -1492,10 +1496,10 @@ uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" version = "2.28.2+0" [[deps.Metalhead]] -deps = ["Artifacts", "BSON", "Flux", "Functors", "LazyArtifacts", "MLUtils", "NNlib", "Random", "Statistics"] -git-tree-sha1 = "0e95f91cc5f23610f8f270d7397f307b21e19d2b" +deps = ["Artifacts", "BSON", "CUDA", "ChainRulesCore", "Flux", "Functors", "JLD2", "LazyArtifacts", "MLUtils", "NNlib", "PartialFunctions", "Random", "Statistics"] +git-tree-sha1 = "c093734078e92a4edcf54e850af68ef8cc2c9e03" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" -version = "0.7.4" +version = "0.8.2" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] @@ -1557,21 +1561,19 @@ version = "7.8.3" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "72240e3f5ca031937bd536182cb2c031da5f46dd" +git-tree-sha1 = "3b29fafcdfa66d6673306cf116a2dc243933e2c5" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.8.21" +version = "0.9.5" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" + NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] + NNlibCUDAExt = "CUDA" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - -[[deps.NNlibCUDA]] -deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"] -git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf" -uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" -version = "0.2.7" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NPZ]] deps = ["FileIO", "ZipFile"] @@ -1707,9 +1709,9 @@ version = "1.7.7" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "c1fc26bab5df929a5172f296f25d7d08688fd25b" +git-tree-sha1 = "34205b1204cc83c43cd9cfe53ffbd3b310f6e8c5" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.20" +version = "0.3.1" [[deps.Opus_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1781,6 +1783,11 @@ git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "2.7.2" +[[deps.PartialFunctions]] +git-tree-sha1 = "b3901ea034cfd8aae57a2fa0dde0b0ea18bad1cb" +uuid = "570af359-4316-4cb7-8c74-252c00c2016b" +version = "1.1.1" + [[deps.PeriodicTable]] deps = ["Base64", "Test", "Unitful"] git-tree-sha1 = "9a9731f346797126271405971dfdf4709947718b" @@ -1829,9 +1836,9 @@ version = "0.1.2" [[deps.Polynomials]] deps = ["LinearAlgebra", "RecipesBase", "Setfield", "SparseArrays"] -git-tree-sha1 = "b5d848e70cdf62f6896d29494c2a69ce4610ea8d" +git-tree-sha1 = "ea78a2764f31715093de7ab495e12c0187f231d1" uuid = "f27b6e38-b328-58d1-80ce-0feddd5e7a45" -version = "4.0.3" +version = "4.0.4" [deps.Polynomials.extensions] PolynomialsChainRulesCoreExt = "ChainRulesCore" @@ -2382,9 +2389,9 @@ uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" version = "0.3.5" [[deps.TupleTools]] -git-tree-sha1 = "3c712976c47707ff893cf6ba4354aa14db1d8938" +git-tree-sha1 = "f4dfd6fc59551c4e70fbcd75ee36ef602b0a8f29" uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" -version = "1.3.0" +version = "1.4.2" [[deps.URIs]] git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0" diff --git a/experiments/Project.toml b/experiments/Project.toml index da0565ef..789c7c9b 100644 --- a/experiments/Project.toml +++ b/experiments/Project.toml @@ -21,4 +21,5 @@ Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" TidierData = "fe2206b3-d496-4ee9-a338-6a095c4ece80" TidierPlots = "337ecbd1-5042-4e2a-ae6f-ca776f97570a" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" ghr_jll = "07c12ed4-43bc-5495-8a2a-d5838ef8d533" diff --git a/experiments/models/additional_models.jl b/experiments/models/additional_models.jl index 9691f60f..d6e8f814 100644 --- a/experiments/models/additional_models.jl +++ b/experiments/models/additional_models.jl @@ -62,7 +62,7 @@ Overloads the MLJFlux build function for a LeNet-like convolutional neural netwo """ function MLJFlux.build(b::ResNetBuilder, rng, n_in, n_out) _n_in = Int(sqrt(n_in)) - front = Metalhead.ResNet(18; pretrain=true, inchannels=1) + front = Metalhead.ResNet(18; inchannels=1) d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first back = Flux.Chain( Dense(d, 120, relu), -- GitLab