From a1aff09648ded0f8641998ab47ad0f943e3d20ed Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Wed, 6 Sep 2023 21:14:51 +0200
Subject: [PATCH] Replace LeNet preprocessing closure with ECCCo.ToConv functor and document utils

---
 experiments/mnist.jl                    |  2 +-
 experiments/models/additional_models.jl |  3 +--
 src/utils.jl                            | 19 +++++++++++++++++++
 3 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/experiments/mnist.jl b/experiments/mnist.jl
index a3fbbaae..627e1d25 100644
--- a/experiments/mnist.jl
+++ b/experiments/mnist.jl
@@ -13,7 +13,7 @@ test_data = load_mnist_test()
 
 # Additional models:
 add_models = Dict(
-    :lenet5 => lenet5,
+    "LeNet-5" => lenet5,
 )
 
 # Default builder:
diff --git a/experiments/models/additional_models.jl b/experiments/models/additional_models.jl
index a5b354d6..f6ed7451 100644
--- a/experiments/models/additional_models.jl
+++ b/experiments/models/additional_models.jl
@@ -21,7 +21,6 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out)
 	k, c1, c2 = b.filter_size, b.channels1, b.channels2
 	mod(k, 2) == 1 || error("`filter_size` must be odd. ")
     p = div(k - 1, 2) # padding to preserve image size on convolution:
-    preproc(x) = reshape(x, (_n_in, _n_in, 1, :))
 
     # Model:
 	front = Flux.Chain(
@@ -37,7 +36,7 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out)
         Dense(120, 84, relu),
         Dense(84, n_out),
     )
-    chain = Flux.Chain(preproc, front, back)
+    chain = Flux.Chain(ECCCo.ToConv(_n_in), front, back)
 
 	return chain
 end
diff --git a/src/utils.jl b/src/utils.jl
index 968fb26e..87c4d0a7 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,5 +1,24 @@
+"""
+    pre_process(x; noise=0.03f0)
+
+Add small Gaussian noise (scaled by `noise`) to the input `x` and return the perturbed array.
+"""
 function pre_process(x; noise::Float32=0.03f0)
     ϵ = Float32.(randn(size(x)) * noise)
     x += ϵ
     return x
+end
+
+"A simple functor to convert a vector to a convolutional layer."
+struct ToConv
+    n_in::Int
+end
+
+"""
+    (f::ToConv)(x)
+
+Reshape the input `x` into a 4D array of size `(f.n_in, f.n_in, 1, :)` so it can be fed to convolutional layers.
+"""
+function (f::ToConv)(x)
+    return reshape(x, (f.n_in, f.n_in, 1, :))
 end
\ No newline at end of file
-- 
GitLab