From 4bd5711db8b8389e4552c2ebdfed710decb9f352 Mon Sep 17 00:00:00 2001 From: pat-alt <altmeyerpat@gmail.com> Date: Tue, 22 Aug 2023 11:55:37 +0200 Subject: [PATCH] more work on streamlining :cry: --- .gitignore | 2 +- experiments/setup.jl | 31 ++++++++++++++++--------------- notebooks/setup.jl | 1 - 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index cf596d48..9576191e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ /artifacts/ /.quarto/ /Manifest.toml -/replicated/ +/results/ **/.CondaPkg /dev/rebuttal/www diff --git a/experiments/setup.jl b/experiments/setup.jl index a167bf9d..e50984a5 100644 --- a/experiments/setup.jl +++ b/experiments/setup.jl @@ -1,22 +1,18 @@ # General setup: include("$(pwd())/notebooks/setup.jl") eval(setup_notebooks) -output_path = "$(pwd())/replicated" -isdir(output_path) || mkdir(output_path) -@info "All results will be saved to $output_path." -params_path = "$(pwd())/replicated/params" -isdir(params_path) || mkdir(params_path) -@info "All parameter choices will be saved to $params_path." + test_size = 0.2 # Constants: +const DEFAULT_OUTPUT_PATH = "$(pwd())/results" const RETRAIN = "retrain" ∈ ARGS ? true : false -# Artifacts: -using LazyArtifacts -@warn "Models were pre-trained on `julia-1.8.5` and may not work on other versions." -artifact_path = joinpath(artifact"results-paper-submission-1.8.5","results-paper-submission-1.8.5") -pretrained_path = joinpath(artifact_path, "results") +# Pre-trained models: +function pretrained_path() + @info "Models were pre-trained on `julia-1.8.5` and may not work on other versions." + return joinpath(artifact"results-paper-submission-1.8.5", "results-paper-submission-1.8.5") +end # Scripts: include("data/data.jl") @@ -28,8 +24,8 @@ Base.@kwdef struct Experiment counterfactual_data::CounterfactualData test_data::CounterfactualData dataname::String = "dataset" - output_path::String = output_path - pretrained_path::String = pretrained_path + output_path::String = DEFAULT_OUTPUT_PATH + params_path::String = joinpath(output_path, "params") use_pretrained::Bool = true models::Union{Nothing, Dict} = nothing builder::Union{Nothing, MLJFlux.GenericBuilder} = nothing @@ -46,8 +42,13 @@ end Run the experiment specified by `exp`. """ function run_experiment(exp::Experiment) - + # SETUP ---------- + @info "All results will be saved to $(exp.output_path)." + isdir(exp.output_path) || mkdir(exp.output_path) + @info "All parameter choices will be saved to $(exp.params_path)." + isdir(exp.params_path) || mkdir(exp.params_path) + # Data X, labels, n_obs, save_name, batch_size, sampler = prepare_data( counterfactual_data; @@ -75,7 +76,7 @@ function run_experiment(exp::Experiment) Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict) else @info "Loading pre-trained models." - model_dict = Serialization.deserialize(joinpath(pretrained_path, "$(save_name)_models.jls")) + model_dict = Serialization.deserialize(joinpath(pretrained_path(), "$(save_name)_models.jls")) end params = DataFrame( diff --git a/notebooks/setup.jl b/notebooks/setup.jl index 058083da..c5c215d3 100644 --- a/notebooks/setup.jl +++ b/notebooks/setup.jl @@ -24,7 +24,6 @@ setup_notebooks = quote using Flux using Images using JointEnergyModels - using LaplaceRedux: LaplaceApproximation using LinearAlgebra using Markdown using MLDatasets -- GitLab