diff --git a/.gitignore b/.gitignore index cf596d488a9552a4b5663a86bcf0015232468a91..9576191e56b0c3ea0e24b58ee7bb24b90256676b 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ /artifacts/ /.quarto/ /Manifest.toml -/replicated/ +/results/ **/.CondaPkg /dev/rebuttal/www diff --git a/experiments/setup.jl b/experiments/setup.jl index a167bf9d40a0b339c3a968a5283a9de32f3926ab..e50984a50e23caf534859c064f5e2f69bd24e51c 100644 --- a/experiments/setup.jl +++ b/experiments/setup.jl @@ -1,22 +1,18 @@ # General setup: include("$(pwd())/notebooks/setup.jl") eval(setup_notebooks) -output_path = "$(pwd())/replicated" -isdir(output_path) || mkdir(output_path) -@info "All results will be saved to $output_path." -params_path = "$(pwd())/replicated/params" -isdir(params_path) || mkdir(params_path) -@info "All parameter choices will be saved to $params_path." + test_size = 0.2 # Constants: +const DEFAULT_OUTPUT_PATH = "$(pwd())/results" const RETRAIN = "retrain" ∈ ARGS ? true : false -# Artifacts: -using LazyArtifacts -@warn "Models were pre-trained on `julia-1.8.5` and may not work on other versions." -artifact_path = joinpath(artifact"results-paper-submission-1.8.5","results-paper-submission-1.8.5") -pretrained_path = joinpath(artifact_path, "results") +# Pre-trained models: +function pretrained_path() + @info "Models were pre-trained on `julia-1.8.5` and may not work on other versions." + return joinpath(artifact"results-paper-submission-1.8.5", "results-paper-submission-1.8.5") +end # Scripts: include("data/data.jl") @@ -28,8 +24,8 @@ Base.@kwdef struct Experiment counterfactual_data::CounterfactualData test_data::CounterfactualData dataname::String = "dataset" - output_path::String = output_path - pretrained_path::String = pretrained_path + output_path::String = DEFAULT_OUTPUT_PATH + params_path::String = joinpath(output_path, "params") use_pretrained::Bool = true models::Union{Nothing, Dict} = nothing builder::Union{Nothing, MLJFlux.GenericBuilder} = nothing @@ -46,8 +42,13 @@ end Run the experiment specified by `exp`. """ function run_experiment(exp::Experiment) - + # SETUP ---------- + @info "All results will be saved to $(exp.output_path)." + isdir(exp.output_path) || mkdir(exp.output_path) + @info "All parameter choices will be saved to $(exp.params_path)." + isdir(exp.params_path) || mkdir(exp.params_path) + # Data X, labels, n_obs, save_name, batch_size, sampler = prepare_data( counterfactual_data; @@ -75,7 +76,7 @@ function run_experiment(exp::Experiment) Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict) else @info "Loading pre-trained models." - model_dict = Serialization.deserialize(joinpath(pretrained_path, "$(save_name)_models.jls")) + model_dict = Serialization.deserialize(joinpath(pretrained_path(), "$(save_name)_models.jls")) end params = DataFrame( diff --git a/notebooks/setup.jl b/notebooks/setup.jl index 058083dafb2d3417de36f51d6386ae539a0365d7..c5c215d3dbccee530df97104fafd6fb9b02ce663 100644 --- a/notebooks/setup.jl +++ b/notebooks/setup.jl @@ -24,7 +24,6 @@ setup_notebooks = quote using Flux using Images using JointEnergyModels - using LaplaceRedux: LaplaceApproximation using LinearAlgebra using Markdown using MLDatasets