diff --git a/experiments/experiment.jl b/experiments/experiment.jl index 69db290c616bdf8d6183afdee62d82aba34078b6..b013d15679728ef9f0fbe030c8e3d0155b0198dc 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -125,21 +125,21 @@ end Overload the `run_experiment` function to allow for passing in `CounterfactualData` objects and other keyword arguments. """ -function run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...) +function run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; save_output::Bool=true, kwargs...) # Parameters: exper = Experiment(; counterfactual_data=counterfactual_data, test_data=test_data, kwargs... ) - return run_experiment(exper) + return run_experiment(exper; save_output=save_output) end # Pre-trained models: function pretrained_path(exper::Experiment) if isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls")) @info "Found local pre-trained models in $(DEFAULT_OUTPUT_PATH) and using those." - return exper.output_path + return DEFAULT_OUTPUT_PATH else @info "Using artifacts. Models were pre-trained on `julia-$(LATEST_VERSION)` and may not work on other versions." Pkg.Artifacts.download_artifact(ARTIFACT_HASH, ARTIFACT_TOML) diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index 3abd82b6f82e6ce5b9bd33abf2c731b9c29f0b38..6bcadaab7f0c5b0e155e2ccfaebdf74fa245bac2 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -29,9 +29,9 @@ function grid_search( for tuning_params in grid outcome = run_experiment( counterfactual_data, test_data; + save_output=false, dataname=dataname, output_path=grid_search_path, - save_output=false, tuning_params..., kwargs..., ) diff --git a/experiments/models/models.jl b/experiments/models/models.jl index 822128b27a1512424fe638589c3525b22804a502..0ce02068b8a537edc083f445d9b350cd54fa640b 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -76,8 +76,8 @@ function prepare_models(exper::Experiment) # Save models: if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) - @info "Saving models to $(joinpath(exper.output_path, "$(exper.save_name)_models.jls"))." - Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_models.jls"), model_dict) + @info "Saving models to $(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))." + Serialization.serialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"), model_dict) end return model_dict