From a1aff09648ded0f8641998ab47ad0f943e3d20ed Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Wed, 6 Sep 2023 21:14:51 +0200 Subject: [PATCH] more things --- experiments/mnist.jl | 2 +- experiments/models/additional_models.jl | 3 +-- src/utils.jl | 19 +++++++++++++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/experiments/mnist.jl b/experiments/mnist.jl index a3fbbaae..627e1d25 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -13,7 +13,7 @@ test_data = load_mnist_test() # Additional models: add_models = Dict( - :lenet5 => lenet5, + "LeNet-5" => lenet5, ) # Default builder: diff --git a/experiments/models/additional_models.jl b/experiments/models/additional_models.jl index a5b354d6..f6ed7451 100644 --- a/experiments/models/additional_models.jl +++ b/experiments/models/additional_models.jl @@ -21,7 +21,6 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out) k, c1, c2 = b.filter_size, b.channels1, b.channels2 mod(k, 2) == 1 || error("`filter_size` must be odd. ") p = div(k - 1, 2) # padding to preserve image size on convolution: - preproc(x) = reshape(x, (_n_in, _n_in, 1, :)) # Model: front = Flux.Chain( @@ -37,7 +36,7 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out) Dense(120, 84, relu), Dense(84, n_out), ) - chain = Flux.Chain(preproc, front, back) + chain = Flux.Chain(ECCCo.ToConv(_n_in), front, back) return chain end diff --git a/src/utils.jl b/src/utils.jl index 968fb26e..87c4d0a7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,24 @@ +""" + pre_process(x; noise=0.03f0) + +Helper function to add tiny noise to inputs. +""" function pre_process(x; noise::Float32=0.03f0) ϵ = Float32.(randn(size(x)) * noise) x += ϵ return x +end + +"A simple functor to convert a vector to a convolutional layer." +struct ToConv + n_in::Int +end + +""" + (f::ToConv)(x) + +Method to convert a vector to a convolutional layer. +""" +function (f::ToConv)(x) + return reshape(x, (f.n_in, f.n_in, 1, :)) end \ No newline at end of file -- GitLab