Skip to content
Snippets Groups Projects
Commit 4bd5711d authored by pat-alt's avatar pat-alt
Browse files

more work on streamlining :cry:

parent f2ab5fa4
No related branches found
No related tags found
1 merge request!68Post rebuttal
......@@ -2,7 +2,7 @@
/artifacts/
/.quarto/
/Manifest.toml
/replicated/
/results/
**/.CondaPkg
/dev/rebuttal/www
......
# General setup:
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
output_path = "$(pwd())/replicated"
isdir(output_path) || mkdir(output_path)
@info "All results will be saved to $output_path."
params_path = "$(pwd())/replicated/params"
isdir(params_path) || mkdir(params_path)
@info "All parameter choices will be saved to $params_path."
test_size = 0.2
# Constants:
const DEFAULT_OUTPUT_PATH = "$(pwd())/results"
const RETRAIN = "retrain" ARGS ? true : false
# Artifacts:
using LazyArtifacts
@warn "Models were pre-trained on `julia-1.8.5` and may not work on other versions."
artifact_path = joinpath(artifact"results-paper-submission-1.8.5","results-paper-submission-1.8.5")
pretrained_path = joinpath(artifact_path, "results")
# Pre-trained models:
function pretrained_path()
@info "Models were pre-trained on `julia-1.8.5` and may not work on other versions."
return joinpath(artifact"results-paper-submission-1.8.5", "results-paper-submission-1.8.5")
end
# Scripts:
include("data/data.jl")
......@@ -28,8 +24,8 @@ Base.@kwdef struct Experiment
counterfactual_data::CounterfactualData
test_data::CounterfactualData
dataname::String = "dataset"
output_path::String = output_path
pretrained_path::String = pretrained_path
output_path::String = DEFAULT_OUTPUT_PATH
params_path::String = joinpath(output_path, "params")
use_pretrained::Bool = true
models::Union{Nothing, Dict} = nothing
builder::Union{Nothing, MLJFlux.GenericBuilder} = nothing
......@@ -46,8 +42,13 @@ end
Run the experiment specified by `exp`.
"""
function run_experiment(exp::Experiment)
# SETUP ----------
@info "All results will be saved to $(exp.output_path)."
isdir(exp.output_path) || mkdir(exp.output_path)
@info "All parameter choices will be saved to $(exp.params_path)."
isdir(exp.params_path) || mkdir(exp.params_path)
# Data
X, labels, n_obs, save_name, batch_size, sampler = prepare_data(
counterfactual_data;
......@@ -75,7 +76,7 @@ function run_experiment(exp::Experiment)
Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict)
else
@info "Loading pre-trained models."
model_dict = Serialization.deserialize(joinpath(pretrained_path, "$(save_name)_models.jls"))
model_dict = Serialization.deserialize(joinpath(pretrained_path(), "$(save_name)_models.jls"))
end
params = DataFrame(
......
......@@ -24,7 +24,6 @@ setup_notebooks = quote
using Flux
using Images
using JointEnergyModels
using LaplaceRedux: LaplaceApproximation
using LinearAlgebra
using Markdown
using MLDatasets
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment