diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index fc0efebb793e20afc21c3c8df5b81ba52c827e4d..d06214f541b35bd7ca39e1a9b9fd83682da5d3d7 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -1,3 +1,5 @@ +using DataFrames + """ grid_search( couterfactual_data::CounterfactualData, @@ -24,7 +26,8 @@ function grid_search( # Grid setup: tuning_params = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)] grid = Iterators.product(tuning_params...) - outcomes = Dict{Any,Any}() + df_params = [] + df_outcomes = [] # Search: counter = 1 @@ -40,10 +43,18 @@ function grid_search( tuning_params..., kwargs..., ) - outcomes[tuning_params] = outcome + _df_params = DataFrame(Dict(:id => counter, pairs(tuning_params))) + _df_outcomes = DataFrame(Dict(:id => counter, :params => tuning_params, :outcome => outcome)) + push!(df_params, _df_params) + push!(df_outcomes, _df_outcomes) counter += 1 end + outcomes = Dict( + :df_params => vcat(df_params...), + :df_outcomes => vcat(df_outcomes...), + ) + # Save: if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0) Serialization.serialize( @@ -104,7 +115,7 @@ function best_rank_outcome( df_weights = DataFrame(variable = measure, weight = weights) ranks = [] - for (params, outcome) in outcomes + for outcome in outcomes.df_outcomes.outcome _ranks = generator_rank( outcome; @@ -119,8 +130,8 @@ function best_rank_outcome( end best_index = argmin(ranks) best_outcome = ( - params = collect(keys(outcomes))[best_index], - outcome = collect(values(outcomes))[best_index], + params=outcomes.df_outcomes.params[best_index], + outcome=outcomes.df_outcomes.outcomes[best_index], ) return best_outcome end @@ -148,7 +159,7 @@ function best_absolute_outcome( df_weights = DataFrame(variable = measure, weight = weights) avg_values = [] - for (params, outcome) in outcomes + for (params, outcome) in (outcomes.df_outcomes.params, outcomes.df_outcomes.outcome) # Setup evaluation = deepcopy(outcome.bmk.evaluation) @@ -191,8 +202,8 @@ function best_absolute_outcome( end best_index = argmin(avg_values) best_outcome = ( - params = collect(keys(outcomes))[best_index], - outcome = collect(values(outcomes))[best_index], + params = outcomes.df_outcomes.params[best_index], + outcome = outcomes.df_outcomes.outcomes[best_index], ) end