Skip to content
Snippets Groups Projects
Commit 0b78c792 authored by pat-alt's avatar pat-alt
Browse files

let's try this again then

parent 0b75be12
No related branches found
No related tags found
1 merge request!8985 overshooting
......@@ -40,7 +40,6 @@ params = (
opt = Flux.Optimise.Descent(0.05),
Λ = [0.1, 0.1, 0.1],
reg_strength = 0.0,
n_individuals = 25,
dim_reduction = true,
)
......
......@@ -40,7 +40,6 @@ params = (
opt = Flux.Optimise.Descent(0.05),
Λ = [0.2, 0.2, 0.2],
reg_strength = 0.5,
n_individuals = 25,
dim_reduction = true,
)
......
......@@ -40,7 +40,6 @@ params = (
opt = Flux.Optimise.Descent(0.05),
Λ = [0.1, 0.1, 0.1],
reg_strength = 0.0,
n_individuals = 25,
dim_reduction = true,
)
......
......@@ -36,17 +36,25 @@ function grid_search(
counter = 1
for params in grid
@info "Running experiment $(counter)/$(length(grid)) with tuning parameters: $(params)"
# Filter out keyword parameters that are tuned:
not_these = keys(kwargs)[findall([k in map(k -> k[1], params) for k in keys(kwargs)])]
not_these = (not_these..., :n_individuals)
kwargs = filter(x -> !(x[1] not_these), Base.Pairs(params, keys(kwargs)))
# Run experiment:
outcome = run_experiment(
counterfactual_data,
test_data;
save_output = false,
dataname = dataname,
n_individuals = n_individuals,
output_path = grid_search_path,
params...,
kwargs...,
save_output=false,
dataname=dataname,
n_individuals=n_individuals,
output_path=grid_search_path
)
# Collect:
params = map(x -> typeof(x[2]) <: Vector ? x[1] => Tuple(x[2]) : x[1] => x[2], params)
df_params =
DataFrame(merge(Dict(:id => counter), Dict(params))) |>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment