diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index bfb20a3fa71bcb8e551e4a15c733abb44c481ab4..1bfda483912326e678052e7b2a00007c3b8d5a7f 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -47,8 +47,8 @@ function grid_search( if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0) Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes) Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_"))_best.jls"), best_outcome(outcomes)) - Serialization.serialise(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_"))_best_eccco.jls"), best_eccco(outcomes)) - Serialization.serialise(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_"))_best_eccco_Δ.jls"), best_eccco_Δ(outcomes)) + Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_"))_best_eccco.jls"), best_eccco(outcomes)) + Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_"))_best_eccco_Δ.jls"), best_eccco_Δ(outcomes)) end end diff --git a/experiments/mnist.jl b/experiments/mnist.jl index f7fb962229536739dc320641478657863da6175e..34aae4c3979cfef34277a96f1a6692631f0917ab 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -38,10 +38,10 @@ params = ( epochs=10, nsamples=50, nmin=1, - niter_eccco=10, + niter_eccco=500, Λ=[0.1, 0.25, 0.25], Λ_Δ=[0.1, 0.1, 2.5], - opt=Flux.Optimise.Descent(0.1) + opt=Flux.Optimise.Adam(0.1) ) if !GRID_SEARCH