From 4bd5711db8b8389e4552c2ebdfed710decb9f352 Mon Sep 17 00:00:00 2001
From: pat-alt <altmeyerpat@gmail.com>
Date: Tue, 22 Aug 2023 11:55:37 +0200
Subject: [PATCH] more work on streamlining :cry:

---
 .gitignore           |  2 +-
 experiments/setup.jl | 31 ++++++++++++++++---------------
 notebooks/setup.jl   |  1 -
 3 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/.gitignore b/.gitignore
index cf596d48..9576191e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,7 @@
 /artifacts/
 /.quarto/
 /Manifest.toml
-/replicated/
+/results/
 **/.CondaPkg
 /dev/rebuttal/www
 
diff --git a/experiments/setup.jl b/experiments/setup.jl
index a167bf9d..e50984a5 100644
--- a/experiments/setup.jl
+++ b/experiments/setup.jl
@@ -1,22 +1,18 @@
 # General setup:
 include("$(pwd())/notebooks/setup.jl")
 eval(setup_notebooks)
-output_path = "$(pwd())/replicated"
-isdir(output_path) || mkdir(output_path)
-@info "All results will be saved to $output_path."
-params_path = "$(pwd())/replicated/params"
-isdir(params_path) || mkdir(params_path)
-@info "All parameter choices will be saved to $params_path."
+
 test_size = 0.2
 
 # Constants:
+const DEFAULT_OUTPUT_PATH = "$(pwd())/results"
 const RETRAIN = "retrain" ∈ ARGS ? true : false
 
-# Artifacts:
-using LazyArtifacts
-@warn "Models were pre-trained on `julia-1.8.5` and may not work on other versions."
-artifact_path = joinpath(artifact"results-paper-submission-1.8.5","results-paper-submission-1.8.5")
-pretrained_path = joinpath(artifact_path, "results")
+# Pre-trained models:
+function pretrained_path()
+    @info "Models were pre-trained on `julia-1.8.5` and may not work on other versions."
+    return joinpath(artifact"results-paper-submission-1.8.5", "results-paper-submission-1.8.5")
+end
 
 # Scripts:
 include("data/data.jl")
@@ -28,8 +24,8 @@ Base.@kwdef struct Experiment
     counterfactual_data::CounterfactualData
     test_data::CounterfactualData
     dataname::String = "dataset"
-    output_path::String = output_path
-    pretrained_path::String = pretrained_path
+    output_path::String = DEFAULT_OUTPUT_PATH
+    params_path::String = joinpath(output_path, "params")
     use_pretrained::Bool = true
     models::Union{Nothing, Dict} = nothing
     builder::Union{Nothing, MLJFlux.GenericBuilder} = nothing
@@ -46,8 +42,13 @@ end
 Run the experiment specified by `exp`.
 """
 function run_experiment(exp::Experiment)
-
+    
     # SETUP ----------
+    @info "All results will be saved to $(exp.output_path)."
+    isdir(exp.output_path) || mkdir(exp.output_path)
+    @info "All parameter choices will be saved to $(exp.params_path)."
+    isdir(exp.params_path) || mkdir(exp.params_path)
+
     # Data
     X, labels, n_obs, save_name, batch_size, sampler = prepare_data(
         counterfactual_data;
@@ -75,7 +76,7 @@ function run_experiment(exp::Experiment)
         Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict)
     else
         @info "Loading pre-trained models."
-        model_dict = Serialization.deserialize(joinpath(pretrained_path, "$(save_name)_models.jls"))
+        model_dict = Serialization.deserialize(joinpath(pretrained_path(), "$(save_name)_models.jls"))
     end
 
     params = DataFrame(
diff --git a/notebooks/setup.jl b/notebooks/setup.jl
index 058083da..c5c215d3 100644
--- a/notebooks/setup.jl
+++ b/notebooks/setup.jl
@@ -24,7 +24,6 @@ setup_notebooks = quote
     using Flux
     using Images
     using JointEnergyModels
-    using LaplaceRedux: LaplaceApproximation
     using LinearAlgebra
     using Markdown
     using MLDatasets
-- 
GitLab