diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl index 74ffc5a638333e6f8f4b3da9b0636b6b684380f1..f07bd6dbbd541814d7ef9225f82cee8988ab7509 100644 --- a/experiments/benchmarking/benchmarking.jl +++ b/experiments/benchmarking/benchmarking.jl @@ -175,6 +175,7 @@ function run_benchmark(exper::Experiment, model_dict::Dict) converge_when = :generator_conditions, parallelizer = parallelizer, store_ce = exper.store_ce, + n_runs = exper.n_runs, ) return bmk, generator_dict end diff --git a/experiments/experiment.jl b/experiments/experiment.jl index 333e7478bf91a114c21f3bc72b905ae46ea2802c..b1a68c738ba71d7cbac4ba96aa0e3d8143356f8c 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -25,6 +25,7 @@ Base.@kwdef struct Experiment coverage::Float64 = DEFAULT_COVERAGE generators::Union{Nothing,Dict} = nothing n_individuals::Int = N_IND + n_runs::Int = N_RUNS ce_measures::AbstractArray = CE_MEASURES model_measures::Dict = MODEL_MEASURES use_class_loss::Bool = false diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index 3a62e9a37841b638857e4fc8b3a116ee970346c5..a648675c7ff2f47d01b93c8e6f4fc5ed9fa9d1e3 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -47,6 +47,7 @@ function grid_search( kwargs..., ) + params = map(x -> typeof(x[2]) <: Vector ? x[1] => Tuple(x[2]) : x[1] => x[2], params) df_params = DataFrame(merge(Dict(:id => counter), Dict(params))) |> x -> select(x, :id, Not(:id)) diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl index 4a68251fa6c640d9b47930a4156fc584dafee7d5..a7b2108a8c89bd313c5bae222778cb9c76c6bd55 100644 --- a/experiments/setup_env.jl +++ b/experiments/setup_env.jl @@ -16,6 +16,7 @@ using DataFrames using Distributions: Normal, Distribution, Categorical, Uniform using ECCCo using Flux +using Flux.Optimise: Optimiser, Descent, Adam, ClipValue using JointEnergyModels using LazyArtifacts using Logging @@ -64,6 +65,7 @@ const N_IND = n_individuals "Boolean flag to check if number of individuals was specified." const N_IND_SPECIFIED = n_ind_specified +# Number of tasks per process: if any(contains.(ARGS, "n_each=")) n_each = ARGS[findall(contains.(ARGS, "n_each="))][1] |> @@ -75,6 +77,18 @@ end "Number of objects to pass to each process." const N_EACH = n_each +# Number of benchmark runs: +if any(contains.(ARGS, "n_runs=")) + n_runs = + ARGS[findall(contains.(ARGS, "n_runs="))][1] |> + x -> replace(x, "n_runs=" => "") |> x -> parse(Int, x) +else + n_runs = 1 +end + +"Number of benchmark runs." +const N_RUNS = n_runs + # Parallelization: plz = nothing @@ -171,19 +185,24 @@ DEFAULT_GENERATOR_TUNING = ( Λ = [[0.1, 0.1, 0.05], [0.1, 0.1, 0.1], [0.1, 0.1, 0.5], [0.1, 0.1, 1.0]], reg_strength = [0.0, 0.1, 0.25, 0.5, 1.0], opt = [ - Flux.Optimise.Descent(0.1), - Flux.Optimise.Descent(0.05), - Flux.Optimise.Descent(0.01), + Descent(0.01), + Descent(0.05), + Descent(0.1), ], - decay = [(0.0, 1), (0.1, 1), (0.5, 1)], + decay = [(0.0, 1), (0.01, 1), (0.1, 1)], ) "Generator tuning parameters for large datasets." DEFAULT_GENERATOR_TUNING_LARGE = ( Λ = [[0.1, 0.1, 0.1], [0.1, 0.1, 0.2], [0.2, 0.2, 0.2]], - reg_strength = [0.0], - opt = [Flux.Optimise.Descent(0.01), Flux.Optimise.Descent(0.05)], - decay = [(0.0, 1), (0.1, 1), (0.5, 1)], + reg_strength = [0.0, 0.1, 0.25,], + opt = [ + Descent(0.01), + Descent(0.05), + Optimiser(ClipValue(0.01), Descent(0.01)), + Optimiser(ClipValue(0.05), Descent(0.05)), + ], + decay = [(0.0, 1), (0.01, 1), (0.1, 1)], ) "Boolean flag to check if model tuning was specified."