From cdc331f71ea3b172a6130ab9a90b71be2f0bdde2 Mon Sep 17 00:00:00 2001 From: pat-alt <altmeyerpat@gmail.com> Date: Fri, 29 Sep 2023 15:07:33 +0200 Subject: [PATCH] moving grid search results into temporary storage --- experiments/grid_search.jl | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index e2f66709..5bdb5aaa 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -26,8 +26,11 @@ function grid_search( # Grid setup: tuning_params = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)] grid = Iterators.product(tuning_params...) - df_params = [] - df_outcomes = [] + + # Temporary storage on disk: + storage_path = joinpath(grid_search_path, ".tmp_results_$(replace(lowercase(dataname), " " => "_"))") + mkpath(storage_path) + @info "Storing temporary results in $(storage_path)." # Search: counter = 1 @@ -44,21 +47,39 @@ function grid_search( kwargs..., ) - _df_params = + df_params = DataFrame(merge(Dict(:id => counter), Dict(params))) |> x -> select(x, :id, Not(:id)) - _df_outcomes = + df_outcomes = DataFrame(Dict(:id => counter, :params => params, :outcome => outcome)) |> x -> select(x, :id, Not(:id)) - push!(df_params, _df_params) - push!(df_outcomes, _df_outcomes) + + # Save: + if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0) + Serialization.serialize( + joinpath(storage_path, "$(params)_$(counter).jls"), + df_params, + ) + Serialization.serialize( + joinpath(storage_path, "$(outcomes)_$(counter).jls"), + df_outcomes, + ) + end counter += 1 end - outcomes = Dict(:df_params => vcat(df_params...), :df_outcomes => vcat(df_outcomes...)) - # Save: if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0) + + # Deserialise: + df_params = [] + df_outcomes = [] + for i in 1:length(grid) + df_params = push!(df_params, Serialization.deserialize(joinpath(storage_path_params, "$(i).jls"))) + df_outcomes = push!(df_outcomes; Serialization.deserialize(joinpath(storage_path_outcomes, "$(i).jls"))) + end + outcomes = Dict(:df_params => vcat(df_params...), :df_outcomes => vcat(df_outcomes...)) + Serialization.serialize( joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes, @@ -85,6 +106,7 @@ function grid_search( best_absolute_outcome_eccco_Δ(outcomes), ) end + end const ALL_ECCCO_NAMES = [ -- GitLab