From b674b4a1ea7e26205af59119e0e1c2e9a156a211 Mon Sep 17 00:00:00 2001 From: pat-alt <altmeyerpat@gmail.com> Date: Wed, 27 Sep 2023 13:03:14 +0200 Subject: [PATCH] rejigged grid search a little --- experiments/grid_search.jl | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index fc0efebb..d06214f5 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -1,3 +1,5 @@ +using DataFrames + """ grid_search( couterfactual_data::CounterfactualData, @@ -24,7 +26,8 @@ function grid_search( # Grid setup: tuning_params = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)] grid = Iterators.product(tuning_params...) - outcomes = Dict{Any,Any}() + df_params = [] + df_outcomes = [] # Search: counter = 1 @@ -40,10 +43,18 @@ function grid_search( tuning_params..., kwargs..., ) - outcomes[tuning_params] = outcome + _df_params = DataFrame(Dict(:id => counter, pairs(tuning_params))) + _df_outcomes = DataFrame(Dict(:id => counter, :params => tuning_params, :outcome => outcome)) + push!(df_params, _df_params) + push!(df_outcomes, _df_outcomes) counter += 1 end + outcomes = Dict( + :df_params => vcat(df_params...), + :df_outcomes => vcat(df_outcomes...), + ) + # Save: if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0) Serialization.serialize( @@ -104,7 +115,7 @@ function best_rank_outcome( df_weights = DataFrame(variable = measure, weight = weights) ranks = [] - for (params, outcome) in outcomes + for outcome in outcomes.df_outcomes.outcome _ranks = generator_rank( outcome; @@ -119,8 +130,8 @@ function best_rank_outcome( end best_index = argmin(ranks) best_outcome = ( - params = collect(keys(outcomes))[best_index], - outcome = collect(values(outcomes))[best_index], + params=outcomes.df_outcomes.params[best_index], + outcome=outcomes.df_outcomes.outcomes[best_index], ) return best_outcome end @@ -148,7 +159,7 @@ function best_absolute_outcome( df_weights = DataFrame(variable = measure, weight = weights) avg_values = [] - for (params, outcome) in outcomes + for (params, outcome) in (outcomes.df_outcomes.params, outcomes.df_outcomes.outcome) # Setup evaluation = deepcopy(outcome.bmk.evaluation) @@ -191,8 +202,8 @@ function best_absolute_outcome( end best_index = argmin(avg_values) best_outcome = ( - params = collect(keys(outcomes))[best_index], - outcome = collect(values(outcomes))[best_index], + params = outcomes.df_outcomes.params[best_index], + outcome = outcomes.df_outcomes.outcomes[best_index], ) end -- GitLab