Skip to content
Snippets Groups Projects
Commit 4bd5711d authored by pat-alt's avatar pat-alt
Browse files

more work on streamlining :cry:

parent f2ab5fa4
No related branches found
No related tags found
1 merge request!68Post rebuttal
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
/artifacts/ /artifacts/
/.quarto/ /.quarto/
/Manifest.toml /Manifest.toml
/replicated/ /results/
**/.CondaPkg **/.CondaPkg
/dev/rebuttal/www /dev/rebuttal/www
......
# General setup: # General setup:
include("$(pwd())/notebooks/setup.jl") include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks) eval(setup_notebooks)
output_path = "$(pwd())/replicated"
isdir(output_path) || mkdir(output_path)
@info "All results will be saved to $output_path."
params_path = "$(pwd())/replicated/params"
isdir(params_path) || mkdir(params_path)
@info "All parameter choices will be saved to $params_path."
test_size = 0.2 test_size = 0.2
# Constants: # Constants:
const DEFAULT_OUTPUT_PATH = "$(pwd())/results"
const RETRAIN = "retrain" ARGS ? true : false const RETRAIN = "retrain" ARGS ? true : false
# Artifacts: # Pre-trained models:
using LazyArtifacts function pretrained_path()
@warn "Models were pre-trained on `julia-1.8.5` and may not work on other versions." @info "Models were pre-trained on `julia-1.8.5` and may not work on other versions."
artifact_path = joinpath(artifact"results-paper-submission-1.8.5","results-paper-submission-1.8.5") return joinpath(artifact"results-paper-submission-1.8.5", "results-paper-submission-1.8.5")
pretrained_path = joinpath(artifact_path, "results") end
# Scripts: # Scripts:
include("data/data.jl") include("data/data.jl")
...@@ -28,8 +24,8 @@ Base.@kwdef struct Experiment ...@@ -28,8 +24,8 @@ Base.@kwdef struct Experiment
counterfactual_data::CounterfactualData counterfactual_data::CounterfactualData
test_data::CounterfactualData test_data::CounterfactualData
dataname::String = "dataset" dataname::String = "dataset"
output_path::String = output_path output_path::String = DEFAULT_OUTPUT_PATH
pretrained_path::String = pretrained_path params_path::String = joinpath(output_path, "params")
use_pretrained::Bool = true use_pretrained::Bool = true
models::Union{Nothing, Dict} = nothing models::Union{Nothing, Dict} = nothing
builder::Union{Nothing, MLJFlux.GenericBuilder} = nothing builder::Union{Nothing, MLJFlux.GenericBuilder} = nothing
...@@ -46,8 +42,13 @@ end ...@@ -46,8 +42,13 @@ end
Run the experiment specified by `exp`. Run the experiment specified by `exp`.
""" """
function run_experiment(exp::Experiment) function run_experiment(exp::Experiment)
# SETUP ---------- # SETUP ----------
@info "All results will be saved to $(exp.output_path)."
isdir(exp.output_path) || mkdir(exp.output_path)
@info "All parameter choices will be saved to $(exp.params_path)."
isdir(exp.params_path) || mkdir(exp.params_path)
# Data # Data
X, labels, n_obs, save_name, batch_size, sampler = prepare_data( X, labels, n_obs, save_name, batch_size, sampler = prepare_data(
counterfactual_data; counterfactual_data;
...@@ -75,7 +76,7 @@ function run_experiment(exp::Experiment) ...@@ -75,7 +76,7 @@ function run_experiment(exp::Experiment)
Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict) Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict)
else else
@info "Loading pre-trained models." @info "Loading pre-trained models."
model_dict = Serialization.deserialize(joinpath(pretrained_path, "$(save_name)_models.jls")) model_dict = Serialization.deserialize(joinpath(pretrained_path(), "$(save_name)_models.jls"))
end end
params = DataFrame( params = DataFrame(
......
...@@ -24,7 +24,6 @@ setup_notebooks = quote ...@@ -24,7 +24,6 @@ setup_notebooks = quote
using Flux using Flux
using Images using Images
using JointEnergyModels using JointEnergyModels
using LaplaceRedux: LaplaceApproximation
using LinearAlgebra using LinearAlgebra
using Markdown using Markdown
using MLDatasets using MLDatasets
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment