Skip to content
Snippets Groups Projects

Post rebuttal

Merged Imported Patrick Altmeyer requested to merge post-rebuttal into main
3 files
+ 17
17
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 16
15
# General setup:
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
output_path = "$(pwd())/replicated"
isdir(output_path) || mkdir(output_path)
@info "All results will be saved to $output_path."
params_path = "$(pwd())/replicated/params"
isdir(params_path) || mkdir(params_path)
@info "All parameter choices will be saved to $params_path."
test_size = 0.2
# Constants:
const DEFAULT_OUTPUT_PATH = "$(pwd())/results"
const RETRAIN = "retrain" ARGS ? true : false
# Artifacts:
using LazyArtifacts
@warn "Models were pre-trained on `julia-1.8.5` and may not work on other versions."
artifact_path = joinpath(artifact"results-paper-submission-1.8.5","results-paper-submission-1.8.5")
pretrained_path = joinpath(artifact_path, "results")
# Pre-trained models:
function pretrained_path()
@info "Models were pre-trained on `julia-1.8.5` and may not work on other versions."
return joinpath(artifact"results-paper-submission-1.8.5", "results-paper-submission-1.8.5")
end
# Scripts:
include("data/data.jl")
@@ -28,8 +24,8 @@ Base.@kwdef struct Experiment
counterfactual_data::CounterfactualData
test_data::CounterfactualData
dataname::String = "dataset"
output_path::String = output_path
pretrained_path::String = pretrained_path
output_path::String = DEFAULT_OUTPUT_PATH
params_path::String = joinpath(output_path, "params")
use_pretrained::Bool = true
models::Union{Nothing, Dict} = nothing
builder::Union{Nothing, MLJFlux.GenericBuilder} = nothing
@@ -46,8 +42,13 @@ end
Run the experiment specified by `exp`.
"""
function run_experiment(exp::Experiment)
# SETUP ----------
@info "All results will be saved to $(exp.output_path)."
isdir(exp.output_path) || mkdir(exp.output_path)
@info "All parameter choices will be saved to $(exp.params_path)."
isdir(exp.params_path) || mkdir(exp.params_path)
# Data
X, labels, n_obs, save_name, batch_size, sampler = prepare_data(
counterfactual_data;
@@ -75,7 +76,7 @@ function run_experiment(exp::Experiment)
Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict)
else
@info "Loading pre-trained models."
model_dict = Serialization.deserialize(joinpath(pretrained_path, "$(save_name)_models.jls"))
model_dict = Serialization.deserialize(joinpath(pretrained_path(), "$(save_name)_models.jls"))
end
params = DataFrame(
Loading