Skip to content
Snippets Groups Projects
Commit b674b4a1 authored by pat-alt's avatar pat-alt
Browse files

rejigged grid search a little

parent d478bcca
No related branches found
No related tags found
1 merge request!8985 overshooting
using DataFrames
""" """
grid_search( grid_search(
couterfactual_data::CounterfactualData, couterfactual_data::CounterfactualData,
...@@ -24,7 +26,8 @@ function grid_search( ...@@ -24,7 +26,8 @@ function grid_search(
# Grid setup: # Grid setup:
tuning_params = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)] tuning_params = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)]
grid = Iterators.product(tuning_params...) grid = Iterators.product(tuning_params...)
outcomes = Dict{Any,Any}() df_params = []
df_outcomes = []
# Search: # Search:
counter = 1 counter = 1
...@@ -40,10 +43,18 @@ function grid_search( ...@@ -40,10 +43,18 @@ function grid_search(
tuning_params..., tuning_params...,
kwargs..., kwargs...,
) )
outcomes[tuning_params] = outcome _df_params = DataFrame(Dict(:id => counter, pairs(tuning_params)))
_df_outcomes = DataFrame(Dict(:id => counter, :params => tuning_params, :outcome => outcome))
push!(df_params, _df_params)
push!(df_outcomes, _df_outcomes)
counter += 1 counter += 1
end end
outcomes = Dict(
:df_params => vcat(df_params...),
:df_outcomes => vcat(df_outcomes...),
)
# Save: # Save:
if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0) if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0)
Serialization.serialize( Serialization.serialize(
...@@ -104,7 +115,7 @@ function best_rank_outcome( ...@@ -104,7 +115,7 @@ function best_rank_outcome(
df_weights = DataFrame(variable = measure, weight = weights) df_weights = DataFrame(variable = measure, weight = weights)
ranks = [] ranks = []
for (params, outcome) in outcomes for outcome in outcomes.df_outcomes.outcome
_ranks = _ranks =
generator_rank( generator_rank(
outcome; outcome;
...@@ -119,8 +130,8 @@ function best_rank_outcome( ...@@ -119,8 +130,8 @@ function best_rank_outcome(
end end
best_index = argmin(ranks) best_index = argmin(ranks)
best_outcome = ( best_outcome = (
params = collect(keys(outcomes))[best_index], params=outcomes.df_outcomes.params[best_index],
outcome = collect(values(outcomes))[best_index], outcome=outcomes.df_outcomes.outcomes[best_index],
) )
return best_outcome return best_outcome
end end
...@@ -148,7 +159,7 @@ function best_absolute_outcome( ...@@ -148,7 +159,7 @@ function best_absolute_outcome(
df_weights = DataFrame(variable = measure, weight = weights) df_weights = DataFrame(variable = measure, weight = weights)
avg_values = [] avg_values = []
for (params, outcome) in outcomes for (params, outcome) in (outcomes.df_outcomes.params, outcomes.df_outcomes.outcome)
# Setup # Setup
evaluation = deepcopy(outcome.bmk.evaluation) evaluation = deepcopy(outcome.bmk.evaluation)
...@@ -191,8 +202,8 @@ function best_absolute_outcome( ...@@ -191,8 +202,8 @@ function best_absolute_outcome(
end end
best_index = argmin(avg_values) best_index = argmin(avg_values)
best_outcome = ( best_outcome = (
params = collect(keys(outcomes))[best_index], params = outcomes.df_outcomes.params[best_index],
outcome = collect(values(outcomes))[best_index], outcome = outcomes.df_outcomes.outcomes[best_index],
) )
end end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment