Skip to content
Snippets Groups Projects
Commit 1ecefaff authored by Pat Alt's avatar Pat Alt
Browse files

uf

parent 2704f5be
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -2,7 +2,7 @@
julia_version = "1.9.3"
manifest_format = "2.0"
project_hash = "610dd01ba3ad744c63925263e3388bfa9abf58ab"
project_hash = "f97f8534bd2363738071587ac7055fab77fc36f4"
[[deps.ARFFFiles]]
deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"]
......
......@@ -17,6 +17,7 @@ MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
TidierData = "fe2206b3-d496-4ee9-a338-6a095c4ece80"
TidierPlots = "337ecbd1-5042-4e2a-ae6f-ca776f97570a"
......
......@@ -48,3 +48,34 @@ Builds a LeNet-like convolutional neural network.
"""
lenet5(builder=LeNetBuilder(5, 6, 16); kwargs...) = NeuralNetworkClassifier(builder=builder; acceleration=CUDALibs(), kwargs...)
"""
ResNetBuilder
MLJFlux builder for a ResNet.
"""
mutable struct ResNetBuilder end
"""
MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out)
Overloads the MLJFlux build function for a LeNet-like convolutional neural network.
"""
function MLJFlux.build(b::ResNetBuilder, rng, n_in, n_out)
_n_in = Int(sqrt(n_in))
front = Metalhead.ResNet(18; pretrain=true, inchannels=1)
d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first
back = Flux.Chain(
Dense(d, 120, relu),
Dense(120, 84, relu),
Dense(84, n_out),
)
chain = Flux.Chain(ECCCo.ToConv(_n_in), front, back)
return chain
end
"""
lenet5(builder=LeNetBuilder(5, 6, 16); kwargs...)
Builds a LeNet-like convolutional neural network.
"""
resnet18(builder=ResNetBuilder(); kwargs...) = NeuralNetworkClassifier(builder=builder; acceleration=CUDALibs(), kwargs...)
......@@ -18,6 +18,7 @@ using Flux
using JointEnergyModels
using LazyArtifacts
using Logging
using Metalhead
using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table, machine, fit!
using MLJEnsembles
using MLJFlux
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment