From c381b5049915829317db81f8785a74416fe15399 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Sun, 17 Sep 2023 10:37:48 +0200
Subject: [PATCH] MNIST still crashing :cry:

---
 experiments/experiment.jl    | 6 +++---
 experiments/grid_search.jl   | 2 +-
 experiments/models/models.jl | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 69db290c..b013d156 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -125,21 +125,21 @@ end
 
 Overload the `run_experiment` function to allow for passing in `CounterfactualData` objects and other keyword arguments.
 """
-function run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...)
+function run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; save_output::Bool=true, kwargs...)
     # Parameters:
     exper = Experiment(;
         counterfactual_data=counterfactual_data,
         test_data=test_data,
         kwargs...
     )
-    return run_experiment(exper)
+    return run_experiment(exper; save_output=save_output)
 end
 
 # Pre-trained models:
 function pretrained_path(exper::Experiment)
     if isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls"))
         @info "Found local pre-trained models in $(DEFAULT_OUTPUT_PATH) and using those."
-        return exper.output_path
+        return DEFAULT_OUTPUT_PATH
     else
         @info "Using artifacts. Models were pre-trained on `julia-$(LATEST_VERSION)` and may not work on other versions."
         Pkg.Artifacts.download_artifact(ARTIFACT_HASH, ARTIFACT_TOML)
diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl
index 3abd82b6..6bcadaab 100644
--- a/experiments/grid_search.jl
+++ b/experiments/grid_search.jl
@@ -29,9 +29,9 @@ function grid_search(
     for tuning_params in grid
         outcome = run_experiment(
             counterfactual_data, test_data;
+            save_output=false,
             dataname=dataname,
             output_path=grid_search_path,
-            save_output=false,
             tuning_params...,
             kwargs...,
         )
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index 822128b2..0ce02068 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -76,8 +76,8 @@ function prepare_models(exper::Experiment)
 
     # Save models:
     if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
-        @info "Saving models to $(joinpath(exper.output_path, "$(exper.save_name)_models.jls"))."
-        Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_models.jls"), model_dict)
+        @info "Saving models to $(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))."
+        Serialization.serialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"), model_dict)
     end
 
     return model_dict
-- 
GitLab