Skip to content
Snippets Groups Projects
Commit b674b4a1 authored by pat-alt's avatar pat-alt
Browse files

rejigged grid search a little

parent d478bcca
No related branches found
No related tags found
1 merge request!8985 overshooting
using DataFrames
"""
grid_search(
couterfactual_data::CounterfactualData,
......@@ -24,7 +26,8 @@ function grid_search(
# Grid setup:
tuning_params = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)]
grid = Iterators.product(tuning_params...)
outcomes = Dict{Any,Any}()
df_params = []
df_outcomes = []
# Search:
counter = 1
......@@ -40,10 +43,18 @@ function grid_search(
tuning_params...,
kwargs...,
)
outcomes[tuning_params] = outcome
_df_params = DataFrame(Dict(:id => counter, pairs(tuning_params)))
_df_outcomes = DataFrame(Dict(:id => counter, :params => tuning_params, :outcome => outcome))
push!(df_params, _df_params)
push!(df_outcomes, _df_outcomes)
counter += 1
end
outcomes = Dict(
:df_params => vcat(df_params...),
:df_outcomes => vcat(df_outcomes...),
)
# Save:
if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0)
Serialization.serialize(
......@@ -104,7 +115,7 @@ function best_rank_outcome(
df_weights = DataFrame(variable = measure, weight = weights)
ranks = []
for (params, outcome) in outcomes
for outcome in outcomes.df_outcomes.outcome
_ranks =
generator_rank(
outcome;
......@@ -119,8 +130,8 @@ function best_rank_outcome(
end
best_index = argmin(ranks)
best_outcome = (
params = collect(keys(outcomes))[best_index],
outcome = collect(values(outcomes))[best_index],
params=outcomes.df_outcomes.params[best_index],
outcome=outcomes.df_outcomes.outcomes[best_index],
)
return best_outcome
end
......@@ -148,7 +159,7 @@ function best_absolute_outcome(
df_weights = DataFrame(variable = measure, weight = weights)
avg_values = []
for (params, outcome) in outcomes
for (params, outcome) in (outcomes.df_outcomes.params, outcomes.df_outcomes.outcome)
# Setup
evaluation = deepcopy(outcome.bmk.evaluation)
......@@ -191,8 +202,8 @@ function best_absolute_outcome(
end
best_index = argmin(avg_values)
best_outcome = (
params = collect(keys(outcomes))[best_index],
outcome = collect(values(outcomes))[best_index],
params = outcomes.df_outcomes.params[best_index],
outcome = outcomes.df_outcomes.outcomes[best_index],
)
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment