Skip to content
Snippets Groups Projects
Commit 10742496 authored by Pat Alt's avatar Pat Alt
Browse files

fix:

parent 73db7284
No related branches found
No related tags found
1 merge request!8985 overshooting
......@@ -65,7 +65,7 @@ function default_generators(;
nmin = nmin,
niter = niter_eccco,
reg_strength = reg_strength,
decay=decay,
decay = decay,
),
"ECCCo-Δ (no EBM)" => ECCCoGenerator(
λ = [λ₁_Δ, λ₂_Δ, 0.0],
......@@ -76,7 +76,7 @@ function default_generators(;
nmin = nmin,
niter = niter_eccco,
reg_strength = reg_strength,
decay=decay,
decay = decay,
),
)
else
......@@ -101,7 +101,7 @@ function default_generators(;
nmin = nmin,
niter = niter_eccco,
reg_strength = reg_strength,
decay=decay,
decay = decay,
),
)
end
......@@ -119,7 +119,7 @@ function default_generators(;
nmin = nmin,
niter = niter_eccco,
reg_strength = reg_strength,
decay=decay,
decay = decay,
dim_reduction = dim_reduction,
),
)
......
......@@ -31,8 +31,8 @@ function grid_search(
# Search:
counter = 1
for tuning_params in grid
@info "Running experiment $(counter)/$(length(grid)) with tuning parameters: $(tuning_params)"
for params in grid
@info "Running experiment $(counter)/$(length(grid)) with tuning parameters: $(params)"
outcome = run_experiment(
counterfactual_data,
test_data;
......@@ -40,20 +40,22 @@ function grid_search(
dataname = dataname,
n_individuals = n_individuals,
output_path = grid_search_path,
tuning_params...,
params...,
kwargs...,
)
_df_params = DataFrame(merge(Dict(:id => counter), Dict(pairs(tuning_params))))
_df_outcomes = DataFrame(Dict(:id => counter, :params => tuning_params, :outcome => outcome))
_df_params =
DataFrame(merge(Dict(:id => counter), Dict(params))) |>
x -> select(x, :id, Not(:id))
_df_outcomes =
DataFrame(Dict(:id => counter, :params => params, :outcome => outcome)) |>
x -> select(x, :id, Not(:id))
push!(df_params, _df_params)
push!(df_outcomes, _df_outcomes)
counter += 1
end
outcomes = Dict(
:df_params => vcat(df_params...),
:df_outcomes => vcat(df_outcomes...),
)
outcomes = Dict(:df_params => vcat(df_params...), :df_outcomes => vcat(df_outcomes...))
# Save:
if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0)
......@@ -130,8 +132,8 @@ function best_rank_outcome(
end
best_index = argmin(ranks)
best_outcome = (
params=outcomes.df_outcomes.params[best_index],
outcome=outcomes.df_outcomes.outcomes[best_index],
params = outcomes.df_outcomes.params[best_index],
outcome = outcomes.df_outcomes.outcomes[best_index],
)
return best_outcome
end
......
......@@ -163,7 +163,7 @@ DEFAULT_GENERATOR_TUNING = (
Flux.Optimise.Descent(0.05),
Flux.Optimise.Descent(0.01),
],
decay=[(0.0, 1), (0.1, 1), (0.5, 1),],
decay = [(0.0, 1), (0.1, 1), (0.5, 1)],
)
"Generator tuning parameters for large datasets."
......@@ -171,7 +171,7 @@ DEFAULT_GENERATOR_TUNING_LARGE = (
Λ = [[0.1, 0.1, 0.1], [0.1, 0.1, 0.2], [0.2, 0.2, 0.2]],
reg_strength = [0.0],
opt = [Flux.Optimise.Descent(0.01), Flux.Optimise.Descent(0.05)],
decay = [(0.0, 1), (0.1, 1), (0.5, 1),],
decay = [(0.0, 1), (0.1, 1), (0.5, 1)],
)
"Boolean flag to check if model tuning was specified."
......
0% — Loading, or an error occurred while loading the diff.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment