Skip to content
Snippets Groups Projects
Commit 945d2e5e authored by pat-alt's avatar pat-alt
Browse files

Merge branch '85-overshooting' of https://github.com/pat-alt/ECCCo.jl into 85-overshooting

parents 3ec4339c 4a21070e
No related branches found
No related tags found
1 merge request!8985 overshooting
......@@ -175,6 +175,7 @@ function run_benchmark(exper::Experiment, model_dict::Dict)
converge_when = :generator_conditions,
parallelizer = parallelizer,
store_ce = exper.store_ce,
n_runs = exper.n_runs,
)
return bmk, generator_dict
end
......@@ -25,6 +25,7 @@ Base.@kwdef struct Experiment
coverage::Float64 = DEFAULT_COVERAGE
generators::Union{Nothing,Dict} = nothing
n_individuals::Int = N_IND
n_runs::Int = N_RUNS
ce_measures::AbstractArray = CE_MEASURES
model_measures::Dict = MODEL_MEASURES
use_class_loss::Bool = false
......
......@@ -47,6 +47,7 @@ function grid_search(
kwargs...,
)
params = map(x -> typeof(x[2]) <: Vector ? x[1] => Tuple(x[2]) : x[1] => x[2], params)
df_params =
DataFrame(merge(Dict(:id => counter), Dict(params))) |>
x -> select(x, :id, Not(:id))
......
......@@ -16,6 +16,7 @@ using DataFrames
using Distributions: Normal, Distribution, Categorical, Uniform
using ECCCo
using Flux
using Flux.Optimise: Optimiser, Descent, Adam, ClipValue
using JointEnergyModels
using LazyArtifacts
using Logging
......@@ -64,6 +65,7 @@ const N_IND = n_individuals
"Boolean flag to check if number of individuals was specified."
const N_IND_SPECIFIED = n_ind_specified
# Number of tasks per process:
if any(contains.(ARGS, "n_each="))
n_each =
ARGS[findall(contains.(ARGS, "n_each="))][1] |>
......@@ -75,6 +77,18 @@ end
"Number of objects to pass to each process."
const N_EACH = n_each
# Number of benchmark runs:
if any(contains.(ARGS, "n_runs="))
n_runs =
ARGS[findall(contains.(ARGS, "n_runs="))][1] |>
x -> replace(x, "n_runs=" => "") |> x -> parse(Int, x)
else
n_runs = 1
end
"Number of benchmark runs."
const N_RUNS = n_runs
# Parallelization:
plz = nothing
......@@ -171,19 +185,24 @@ DEFAULT_GENERATOR_TUNING = (
Λ = [[0.1, 0.1, 0.05], [0.1, 0.1, 0.1], [0.1, 0.1, 0.5], [0.1, 0.1, 1.0]],
reg_strength = [0.0, 0.1, 0.25, 0.5, 1.0],
opt = [
Flux.Optimise.Descent(0.1),
Flux.Optimise.Descent(0.05),
Flux.Optimise.Descent(0.01),
Descent(0.01),
Descent(0.05),
Descent(0.1),
],
decay = [(0.0, 1), (0.1, 1), (0.5, 1)],
decay = [(0.0, 1), (0.01, 1), (0.1, 1)],
)
"Generator tuning parameters for large datasets."
DEFAULT_GENERATOR_TUNING_LARGE = (
Λ = [[0.1, 0.1, 0.1], [0.1, 0.1, 0.2], [0.2, 0.2, 0.2]],
reg_strength = [0.0],
opt = [Flux.Optimise.Descent(0.01), Flux.Optimise.Descent(0.05)],
decay = [(0.0, 1), (0.1, 1), (0.5, 1)],
reg_strength = [0.0, 0.1, 0.25,],
opt = [
Descent(0.01),
Descent(0.05),
Optimiser(ClipValue(0.01), Descent(0.01)),
Optimiser(ClipValue(0.05), Descent(0.05)),
],
decay = [(0.0, 1), (0.01, 1), (0.1, 1)],
)
"Boolean flag to check if model tuning was specified."
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment