diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl
index 3a62e9a37841b638857e4fc8b3a116ee970346c5..a648675c7ff2f47d01b93c8e6f4fc5ed9fa9d1e3 100644
--- a/experiments/grid_search.jl
+++ b/experiments/grid_search.jl
@@ -47,6 +47,7 @@ function grid_search(
     kwargs...,
 )
 
+    params = map(x -> typeof(x[2]) <: Vector ? x[1] => Tuple(x[2]) : x[1] => x[2], params)
     df_params =
         DataFrame(merge(Dict(:id => counter), Dict(params))) |>
         x -> select(x, :id, Not(:id))
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index 4a68251fa6c640d9b47930a4156fc584dafee7d5..29a64d543666705e273f7a36e4f48db767d72ab0 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -16,6 +16,7 @@ using DataFrames
 using Distributions: Normal, Distribution, Categorical, Uniform
 using ECCCo
 using Flux
+using Flux.Optimise: Optimiser, Descent, Adam, ClipValue
 using JointEnergyModels
 using LazyArtifacts
 using Logging
@@ -171,19 +172,24 @@ DEFAULT_GENERATOR_TUNING = (
     Λ = [[0.1, 0.1, 0.05], [0.1, 0.1, 0.1], [0.1, 0.1, 0.5], [0.1, 0.1, 1.0]],
     reg_strength = [0.0, 0.1, 0.25, 0.5, 1.0],
     opt = [
-        Flux.Optimise.Descent(0.1),
-        Flux.Optimise.Descent(0.05),
-        Flux.Optimise.Descent(0.01),
+        Descent(0.01),
+        Descent(0.05),
+        Descent(0.1),
     ],
-    decay = [(0.0, 1), (0.1, 1), (0.5, 1)],
+    decay = [(0.0, 1), (0.01, 1), (0.1, 1)],
 )
 
 "Generator tuning parameters for large datasets."
 DEFAULT_GENERATOR_TUNING_LARGE = (
     Λ = [[0.1, 0.1, 0.1], [0.1, 0.1, 0.2], [0.2, 0.2, 0.2]],
-    reg_strength = [0.0],
-    opt = [Flux.Optimise.Descent(0.01), Flux.Optimise.Descent(0.05)],
-    decay = [(0.0, 1), (0.1, 1), (0.5, 1)],
+    reg_strength = [0.0, 0.1, 0.25],
+    opt = [
+        Descent(0.01),
+        Descent(0.05),
+        Optimiser(ClipValue(0.01), Descent(0.01)),
+        Optimiser(ClipValue(0.05), Descent(0.05)),
+    ],
+    decay = [(0.0, 1), (0.01, 1), (0.1, 1)],
 )
 
 "Boolean flag to check if model tuning was specified."
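
The `grid_search.jl` change converts any `Vector`-valued parameter to a `Tuple` before the parameter grid is written to a `DataFrame`. Below is a minimal sketch of why this matters, with made-up `params` values rather than the patch's actual inputs: `DataFrames` pseudo-broadcasts scalar values but expands an `AbstractVector` into one row per element, so without the conversion a vector-valued hyperparameter such as `Λ` would be split across several rows instead of stored as a single cell. (Incidentally, `x[2] isa Vector` would be the more idiomatic test, though `typeof(x[2]) <: Vector` is equivalent.)

```julia
# Minimal sketch (illustrative values, not from the patch) of the
# Vector-to-Tuple conversion added in grid_search.jl.
using DataFrames

counter = 1
params = [:Λ => [0.1, 0.1, 0.05], :reg_strength => 0.1]

# Without the conversion, the Vector value is expanded to one row per element:
DataFrame(merge(Dict(:id => counter), Dict(params)))  # 3 rows

# The patch maps Vector values to Tuples, which DataFrames keeps as one cell:
params = map(x -> typeof(x[2]) <: Vector ? x[1] => Tuple(x[2]) : x[1] => x[2], params)
DataFrame(merge(Dict(:id => counter), Dict(params)))  # 1 row
```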
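
The `setup_env.jl` hunks shorten the optimiser names via the new `Flux.Optimise` import and add gradient-value clipping to the large-dataset grid. As a rough sketch of what the composed rule does, assuming Flux's legacy `Flux.Optimise` API that the patch imports (the parameter and gradient arrays below are toy values): `Optimiser` chains its rules left to right, so each gradient is first clamped element-wise by `ClipValue` and then scaled by `Descent`'s step size.

```julia
# Rough sketch of the composed optimiser added for large datasets,
# using Flux's legacy Flux.Optimise API (toy parameter/gradient values).
using Flux
using Flux.Optimise: Optimiser, Descent, ClipValue

opt = Optimiser(ClipValue(0.05), Descent(0.05))

θ = [10.0, -10.0]   # parameters
g = [5.0, -0.004]   # raw gradients

# update! applies ClipValue first (clamping each entry to ±0.05),
# then Descent scales the clipped gradient by 0.05 before the step:
Flux.Optimise.update!(opt, θ, g)
# θ ≈ [10.0 - 0.05 * 0.05, -10.0 + 0.05 * 0.004] = [9.9975, -9.9998]
```

Pairing each clipped variant with its plain counterpart at the same step size lets the grid search test whether clipping itself helps, rather than conflating it with a smaller learning rate.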