From aff3fffa5bfda8d22351f6fbaa8c4df8b4423246 Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:12:59 +0200 Subject: [PATCH] slight change to how grid results are stored --- experiments/grid_search.jl | 1 + experiments/setup_env.jl | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index 3a62e9a3..a648675c 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -47,6 +47,7 @@ function grid_search( kwargs..., ) + params = map(x -> typeof(x[2]) <: Vector ? x[1] => Tuple(x[2]) : x[1] => x[2], params) df_params = DataFrame(merge(Dict(:id => counter), Dict(params))) |> x -> select(x, :id, Not(:id)) diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl index 4a68251f..29a64d54 100644 --- a/experiments/setup_env.jl +++ b/experiments/setup_env.jl @@ -16,6 +16,7 @@ using DataFrames using Distributions: Normal, Distribution, Categorical, Uniform using ECCCo using Flux +using Flux.Optimise: Optimiser, Descent, Adam, ClipValue using JointEnergyModels using LazyArtifacts using Logging @@ -171,19 +172,24 @@ DEFAULT_GENERATOR_TUNING = ( Λ = [[0.1, 0.1, 0.05], [0.1, 0.1, 0.1], [0.1, 0.1, 0.5], [0.1, 0.1, 1.0]], reg_strength = [0.0, 0.1, 0.25, 0.5, 1.0], opt = [ - Flux.Optimise.Descent(0.1), - Flux.Optimise.Descent(0.05), - Flux.Optimise.Descent(0.01), + Descent(0.01), + Descent(0.05), + Descent(0.1), ], - decay = [(0.0, 1), (0.1, 1), (0.5, 1)], + decay = [(0.0, 1), (0.01, 1), (0.1, 1)], ) "Generator tuning parameters for large datasets." DEFAULT_GENERATOR_TUNING_LARGE = ( Λ = [[0.1, 0.1, 0.1], [0.1, 0.1, 0.2], [0.2, 0.2, 0.2]], - reg_strength = [0.0], - opt = [Flux.Optimise.Descent(0.01), Flux.Optimise.Descent(0.05)], - decay = [(0.0, 1), (0.1, 1), (0.5, 1)], + reg_strength = [0.0, 0.1, 0.25,], + opt = [ + Descent(0.01), + Descent(0.05), + Optimiser(ClipValue(0.01), Descent(0.01)), + Optimiser(ClipValue(0.05), Descent(0.05)), + ], + decay = [(0.0, 1), (0.01, 1), (0.1, 1)], ) "Boolean flag to check if model tuning was specified." -- GitLab