Skip to content
Snippets Groups Projects
Commit aff3fffa authored by Pat Alt's avatar Pat Alt
Browse files

slight change to how grid results are stored

parent 8d50e7be
No related branches found
No related tags found
1 merge request!8985 overshooting
......@@ -47,6 +47,7 @@ function grid_search(
kwargs...,
)
params = map(x -> typeof(x[2]) <: Vector ? x[1] => Tuple(x[2]) : x[1] => x[2], params)
df_params =
DataFrame(merge(Dict(:id => counter), Dict(params))) |>
x -> select(x, :id, Not(:id))
......
......@@ -16,6 +16,7 @@ using DataFrames
using Distributions: Normal, Distribution, Categorical, Uniform
using ECCCo
using Flux
using Flux.Optimise: Optimiser, Descent, Adam, ClipValue
using JointEnergyModels
using LazyArtifacts
using Logging
......@@ -171,19 +172,24 @@ DEFAULT_GENERATOR_TUNING = (
# Penalty weights to sweep (one vector per configuration).
Λ = [[0.1, 0.1, 0.05], [0.1, 0.1, 0.1], [0.1, 0.1, 0.5], [0.1, 0.1, 1.0]],
# Regularization strengths for the energy term.
reg_strength = [0.0, 0.1, 0.25, 0.5, 1.0],
# Candidate optimisers (plain gradient descent at three learning rates).
opt = [
    Descent(0.01),
    Descent(0.05),
    Descent(0.1),
],
# Learning-rate decay schedules as (rate, frequency) pairs.
decay = [(0.0, 1), (0.01, 1), (0.1, 1)],
)
"Generator tuning parameters for large datasets."
DEFAULT_GENERATOR_TUNING_LARGE = (
    # Penalty weights to sweep (one vector per configuration).
    Λ = [[0.1, 0.1, 0.1], [0.1, 0.1, 0.2], [0.2, 0.2, 0.2]],
    # Regularization strengths for the energy term.
    reg_strength = [0.0, 0.1, 0.25],
    # Candidate optimisers; ClipValue variants bound per-parameter gradients
    # to stabilise training on large datasets.
    opt = [
        Descent(0.01),
        Descent(0.05),
        Optimiser(ClipValue(0.01), Descent(0.01)),
        Optimiser(ClipValue(0.05), Descent(0.05)),
    ],
    # Learning-rate decay schedules as (rate, frequency) pairs.
    decay = [(0.0, 1), (0.01, 1), (0.1, 1)],
)
"Boolean flag to check if model tuning was specified."
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment