Skip to content
Snippets Groups Projects
Commit f35a4842 authored by pat-alt's avatar pat-alt
Browse files

enable warm restart

parent 17a6283a
No related branches found
No related tags found
1 merge request!8985 overshooting
......@@ -4,6 +4,7 @@ using DataFrames
grid_search(
couterfactual_data::CounterfactualData,
test_data::CounerfactualData;
warm_start::Bool = true,
dataname::String,
tuning_params::NamedTuple,
kwargs...,
......@@ -14,6 +15,7 @@ Perform a grid search over the hyperparameters specified by `tuning_params`. Exp
function grid_search(
couterfactual_data::CounterfactualData,
test_data::CounterfactualData;
warm_start::Bool = true,
dataname::String,
n_individuals::Int = N_IND,
tuning_params::NamedTuple,
......@@ -21,19 +23,37 @@ function grid_search(
)
# Output path:
grid_search_path = mkpath(joinpath(DEFAULT_OUTPUT_PATH, "grid_search"))
grid_search_path = joinpath(DEFAULT_OUTPUT_PATH, "grid_search")
if !isdir(grid_search_path)
mkpath(grid_search_path)
end
# Grid setup:
tuning_params = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)]
grid = Iterators.product(tuning_params...)
n_total = length(grid)
# Temporary storage on disk:
storage_path = joinpath(grid_search_path, ".tmp_results_$(replace(lowercase(dataname), " " => "_"))")
mkpath(storage_path)
if !isdir(storage_path)
mkpath(storage_path)
end
@info "Storing temporary results in $(storage_path)."
# Warm start:
if warm_start
existing_files = readdir(storage_path)
n_files = Int(floor(length(existing_files) / 2))
if n_files > 0
@info "Warm start: $(n_files) existing results found."
grid = Iterators.drop(grid, n_files)
end
counter = n_files + 1
else
counter = 1
end
# Search:
counter = 1
for params in grid
@info "Running experiment $(counter)/$(length(grid)) with tuning parameters: $(params)"
......@@ -84,7 +104,7 @@ function grid_search(
# Deserialise:
df_params = []
df_outcomes = []
for i in 1:length(grid)
for i in 1:n_total
push!(df_params, Serialization.deserialize(joinpath(storage_path, "params_$(i).jls")))
push!(df_outcomes, Serialization.deserialize(joinpath(storage_path, "outcomes_$(i).jls")))
end
......
......@@ -5,7 +5,7 @@
#SBATCH --ntasks=10
#SBATCH --cpus-per-task=10
#SBATCH --partition=compute
#SBATCH --mem-per-cpu=12GB
#SBATCH --mem-per-cpu=16GB
#SBATCH --account=research-eemcs-insy
#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes.
......
......@@ -5,7 +5,7 @@
#SBATCH --ntasks=10
#SBATCH --cpus-per-task=10
#SBATCH --partition=compute
#SBATCH --mem-per-cpu=12GB
#SBATCH --mem-per-cpu=16GB
#SBATCH --account=research-eemcs-insy
#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment