Skip to content
Snippets Groups Projects
Commit 232965b6 authored by pat-alt's avatar pat-alt
Browse files

set up for grid search

parent a4364637
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -80,10 +80,12 @@ Run the experiment specified by `exper`.
function run_experiment(exper::Experiment; save_output::Bool=true, only_models::Bool=ONLY_MODELS)
# Setup
@info "All results will be saved to $(exper.output_path)."
isdir(exper.output_path) || mkdir(exper.output_path)
@info "All parameter choices will be saved to $(exper.params_path)."
isdir(exper.params_path) || mkdir(exper.params_path)
if save_output && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
@info "All results will be saved to $(exper.output_path)."
isdir(exper.output_path) || mkdir(exper.output_path)
@info "All parameter choices will be saved to $(exper.params_path)."
isdir(exper.params_path) || mkdir(exper.params_path)
end
outcome = ExperimentOutcome(exper, nothing, nothing, nothing)
# Models
......@@ -126,8 +128,8 @@ end
# Pre-trained models:
function pretrained_path(exper::Experiment)
if isfile(joinpath(exper.output_path, "$(exper.save_name)_models.jls"))
@info "Found local pre-trained models in $(exper.output_path) and using those."
if isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls"))
@info "Found local pre-trained models in $(DEFAULT_OUTPUT_PATH) and using those."
return exper.output_path
else
@info "Using artifacts. Models were pre-trained on `julia-$(LATEST_VERSION)` and may not work on other versions."
......
"""
grid_search(
couterfactual_data::CounterfactualData,
test_data::CounerfactualData;
dataname::String,
tuning_params::NamedTuple,
kwargs...,
)
Perform a grid search over the hyperparameters specified by `tuning_params`. Experiments will be run for each combination of hyperparameters. Other keyword arguments are passed to `run_experiment` and fixed for all experiments.
"""
function grid_search(
couterfactual_data::CounterfactualData,
test_data::CounerfactualData;
dataname::String,
tuning_params::NamedTuple,
kwargs...,
)
# Output path:
grid_search_path = mkpath(joinpath(DEFAULT_OUTPUT_PATH, "grid_search"))
# Grid setup:
tuning_params = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)]
grid = Iterators.product(tuning_params...)
outcomes = Dict{Any,Any}()
# Search:
for tuning_params in grid
outcome = run_experiment(
counterfactual_data, test_data;
dataname=dataname,
output_path=grid_search_path,
save_output=false,
tuning_params...,
kwargs...,
)
outcomes[tuning_params] = outcome
end
# Save:
Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes)
end
\ No newline at end of file
#!/bin/bash
#SBATCH --job-name="Grid-search Synthetic (ECCCo)"
#SBATCH --time=03:00:00
#SBATCH --ntasks=1000
#SBATCH --cpus-per-task=1
#SBATCH --partition=compute
#SBATCH --mem-per-cpu=4GB
#SBATCH --account=research-eemcs-insy
#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes.
module load 2023r1 openmpi
srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,circles output_path=results mpi grid_search > experiments/synthetic.log
# Data:
n_obs = Int(1000 / (1.0 - TEST_SIZE))
counterfactual_data, test_data = train_test_split(
load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0));
test_size=TEST_SIZE
)
run_experiment(
counterfactual_data, test_data;
dataname="Linearly Separable",
# Tuning parameters:
tuning_params = (
nsamples=[10, 50, 100],
niter_eccco=[20, 50, 100],
Λ=[
[0.1, 0.1, 0.1],
[0.1, 0.2, 0.2],
[0.1, 0.5, 0.5],
]
)
# Parameter choices:
params = (
nsamples=100,
niter_eccco=100,
Λ=[0.1, 0.2, 0.2],
)
\ No newline at end of file
)
if !GRID_SEARCH
run_experiment(
counterfactual_data, test_data;
dataname="Linearly Separable",
params...
)
else
grid_search(
counterfactual_data, test_data;
dataname="Linearly Separable",
tuning_params=tuning_params,
)
end
\ No newline at end of file
......@@ -34,6 +34,7 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # avoid command prompt and j
# Scripts:
include("experiment.jl")
include("grid_search.jl")
include("data/data.jl")
include("models/models.jl")
include("benchmarking/benchmarking.jl")
......@@ -121,6 +122,7 @@ const CE_MEASURES = [
"Test set proportion."
const TEST_SIZE = 0.2
"Boolean flag to check if upload was specified."
const UPLOAD = "upload" ARGS
n_ind_specified = false
......@@ -135,4 +137,7 @@ end
const N_IND = n_individuals
"Boolean flag to check if number of individuals was specified."
const N_IND_SPECIFIED = n_ind_specified
\ No newline at end of file
const N_IND_SPECIFIED = n_ind_specified
"Boolean flag to check if grid search was specified."
const GRID_SEARCH = "grid_search" ARGS
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment