Commit 10742496 authored by Pat Alt's avatar Pat Alt

fix:

parent 73db7284
1 merge request: !89 "85 overshooting"
@@ -65,7 +65,7 @@ function default_generators(;
         nmin = nmin,
         niter = niter_eccco,
         reg_strength = reg_strength,
-        decay=decay,
+        decay = decay,
     ),
     "ECCCo-Δ (no EBM)" => ECCCoGenerator(
         λ = [λ₁_Δ, λ₂_Δ, 0.0],
@@ -76,7 +76,7 @@ function default_generators(;
         nmin = nmin,
         niter = niter_eccco,
         reg_strength = reg_strength,
-        decay=decay,
+        decay = decay,
     ),
 )
 else
@@ -101,7 +101,7 @@ function default_generators(;
         nmin = nmin,
         niter = niter_eccco,
         reg_strength = reg_strength,
-        decay=decay,
+        decay = decay,
     ),
 )
 end
@@ -119,7 +119,7 @@ function default_generators(;
         nmin = nmin,
         niter = niter_eccco,
         reg_strength = reg_strength,
-        decay=decay,
+        decay = decay,
         dim_reduction = dim_reduction,
     ),
 )
...
@@ -31,8 +31,8 @@ function grid_search(
     # Search:
     counter = 1
-    for tuning_params in grid
-        @info "Running experiment $(counter)/$(length(grid)) with tuning parameters: $(tuning_params)"
+    for params in grid
+        @info "Running experiment $(counter)/$(length(grid)) with tuning parameters: $(params)"
         outcome = run_experiment(
             counterfactual_data,
             test_data;
@@ -40,20 +40,22 @@ function grid_search(
             dataname = dataname,
             n_individuals = n_individuals,
             output_path = grid_search_path,
-            tuning_params...,
+            params...,
             kwargs...,
         )
-        _df_params = DataFrame(merge(Dict(:id => counter), Dict(pairs(tuning_params))))
-        _df_outcomes = DataFrame(Dict(:id => counter, :params => tuning_params, :outcome => outcome))
+        _df_params =
+            DataFrame(merge(Dict(:id => counter), Dict(params))) |>
+            x -> select(x, :id, Not(:id))
+        _df_outcomes =
+            DataFrame(Dict(:id => counter, :params => params, :outcome => outcome)) |>
+            x -> select(x, :id, Not(:id))
         push!(df_params, _df_params)
         push!(df_outcomes, _df_outcomes)
         counter += 1
     end
-    outcomes = Dict(
-        :df_params => vcat(df_params...),
-        :df_outcomes => vcat(df_outcomes...),
-    )
+    outcomes = Dict(:df_params => vcat(df_params...), :df_outcomes => vcat(df_outcomes...))
     # Save:
     if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0)
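
A note on the refactor in the hunk above: a DataFrame built from a Dict orders its columns by sorted key, so :id does not automatically land in the first column. The added pipe into select(x, :id, Not(:id)) moves :id to the front while keeping every remaining column. A minimal sketch of the idiom; the column names other than :id are made up for illustration:

    using DataFrames

    # Columns built from a Dict come out in sorted key order, not insertion order:
    df = DataFrame(Dict(:id => [1], :outcome => [0.9], :decay => [0.1]))
    names(df)  # ["decay", "id", "outcome"]

    # Selecting :id first, then Not(:id), reorders without dropping columns:
    df = df |> x -> select(x, :id, Not(:id))
    names(df)  # ["id", "decay", "outcome"]
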
@@ -130,8 +132,8 @@ function best_rank_outcome(
     end
     best_index = argmin(ranks)
     best_outcome = (
-        params=outcomes.df_outcomes.params[best_index],
-        outcome=outcomes.df_outcomes.outcomes[best_index],
+        params = outcomes.df_outcomes.params[best_index],
+        outcome = outcomes.df_outcomes.outcomes[best_index],
     )
     return best_outcome
 end
...
@@ -163,7 +163,7 @@ DEFAULT_GENERATOR_TUNING = (
         Flux.Optimise.Descent(0.05),
         Flux.Optimise.Descent(0.01),
     ],
-    decay=[(0.0, 1), (0.1, 1), (0.5, 1),],
+    decay = [(0.0, 1), (0.1, 1), (0.5, 1)],
 )
 "Generator tuning parameters for large datasets."
@@ -171,7 +171,7 @@ DEFAULT_GENERATOR_TUNING_LARGE = (
     Λ = [[0.1, 0.1, 0.1], [0.1, 0.1, 0.2], [0.2, 0.2, 0.2]],
     reg_strength = [0.0],
     opt = [Flux.Optimise.Descent(0.01), Flux.Optimise.Descent(0.05)],
-    decay = [(0.0, 1), (0.1, 1), (0.5, 1),],
+    decay = [(0.0, 1), (0.1, 1), (0.5, 1)],
 )
 "Boolean flag to check if model tuning was specified."
...
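
For context, each field of these tuning NamedTuples holds a list of candidate values, and grid_search iterates one params combination at a time, splatting it as keyword arguments (params...). A minimal sketch of how such a grid could be expanded into those combinations; the expand_grid helper and the use of Iterators.product are assumptions for illustration, not necessarily the repository's implementation:

    using Flux

    tuning = (
        opt = [Flux.Optimise.Descent(0.01), Flux.Optimise.Descent(0.05)],
        decay = [(0.0, 1), (0.1, 1), (0.5, 1)],
    )

    # Hypothetical helper: Cartesian product of the candidate lists, yielding one
    # NamedTuple per combination, ready to splat as keyword arguments.
    expand_grid(nt::NamedTuple) =
        vec([NamedTuple{keys(nt)}(combo) for combo in Iterators.product(values(nt)...)])

    grid = expand_grid(tuning)  # 2 opts × 3 decays = 6 combinations
    length(grid)                # 6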