From c381b5049915829317db81f8785a74416fe15399 Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Sun, 17 Sep 2023 10:37:48 +0200 Subject: [PATCH] MNIST still crashing :cry: --- experiments/experiment.jl | 6 +++--- experiments/grid_search.jl | 2 +- experiments/models/models.jl | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/experiments/experiment.jl b/experiments/experiment.jl index 69db290c..b013d156 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -125,21 +125,21 @@ end Overload the `run_experiment` function to allow for passing in `CounterfactualData` objects and other keyword arguments. """ -function run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...) +function run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; save_output::Bool=true, kwargs...) # Parameters: exper = Experiment(; counterfactual_data=counterfactual_data, test_data=test_data, kwargs... ) - return run_experiment(exper) + return run_experiment(exper; save_output=save_output) end # Pre-trained models: function pretrained_path(exper::Experiment) if isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls")) @info "Found local pre-trained models in $(DEFAULT_OUTPUT_PATH) and using those." - return exper.output_path + return DEFAULT_OUTPUT_PATH else @info "Using artifacts. Models were pre-trained on `julia-$(LATEST_VERSION)` and may not work on other versions." Pkg.Artifacts.download_artifact(ARTIFACT_HASH, ARTIFACT_TOML) diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index 3abd82b6..6bcadaab 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -29,9 +29,9 @@ function grid_search( for tuning_params in grid outcome = run_experiment( counterfactual_data, test_data; + save_output=false, dataname=dataname, output_path=grid_search_path, - save_output=false, tuning_params..., kwargs..., ) diff --git a/experiments/models/models.jl b/experiments/models/models.jl index 822128b2..0ce02068 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -76,8 +76,8 @@ function prepare_models(exper::Experiment) # Save models: if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) - @info "Saving models to $(joinpath(exper.output_path, "$(exper.save_name)_models.jls"))." - Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_models.jls"), model_dict) + @info "Saving models to $(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))." + Serialization.serialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"), model_dict) end return model_dict -- GitLab