From 8e006ea09873bb329c99ec712ecbf7c9ae75062f Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Tue, 22 Aug 2023 17:58:50 +0200
Subject: [PATCH] basic flow seems to be working now

---
 experiments/experiment.jl          | 3 ++-
 experiments/linearly_separable.jl  | 2 +-
 experiments/models/models.jl       | 2 +-
 experiments/models/train_models.jl | 2 +-
 experiments/setup_env.jl           | 2 +-
 5 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 41559715..1e1845ee 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -6,7 +6,7 @@ Base.@kwdef struct Experiment
     save_name::String = replace(lowercase(dataname), " " => "_")
     output_path::String = DEFAULT_OUTPUT_PATH
     params_path::String = joinpath(output_path, "params")
-    use_pretrained::Bool = true
+    use_pretrained::Bool = !RETRAIN
     models::Union{Nothing,Dict} = nothing
     builder::Union{Nothing,MLJFlux.Builder} = nothing
     𝒟x::Distribution = Normal()
@@ -74,5 +74,6 @@ function run_experiment!(counterfactual_data::CounterfactualData, test_data::Cou
         test_data=test_data,
         kwargs...
     )
+    println(exp.use_pretrained)
     return run_experiment!(exp)
 end
\ No newline at end of file
diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl
index f884bd20..a21444d4 100644
--- a/experiments/linearly_separable.jl
+++ b/experiments/linearly_separable.jl
@@ -3,4 +3,4 @@ counterfactual_data, test_data = train_test_split(
     load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0));
     test_size=TEST_SIZE
 )
-run_experiment!(counterfactual_data, test_data; dataname="Linearly Separable")
\ No newline at end of file
+run_experiment!(counterfactual_data, test_data; dataname="Linearly Separable", use_pretrained=false)
\ No newline at end of file
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index f9d645bf..7ad1c36d 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -21,7 +21,7 @@ function prepare_models(exp::Experiment)
             )
         end
         @info "Training models."
-        model_dict = train_models(models, X, labels; coverage=exp.coverage)
+        model_dict = train_models(models, X, labels; cov=exp.coverage)
     else
         @info "Loading pre-trained models."
         model_dict = Serialization.deserialize(joinpath(pretrained_path(), "results/$(exp.save_name)_models.jls"))
diff --git a/experiments/models/train_models.jl b/experiments/models/train_models.jl
index 429dbe64..863584a4 100644
--- a/experiments/models/train_models.jl
+++ b/experiments/models/train_models.jl
@@ -20,7 +20,7 @@ end
 
 Trains a model and returns a `ConformalModel` object.
 """
-function _train(model, X, y; cov=coverage, method=:simple_inductive, mod_name="model")
+function _train(model, X, y; cov, method=:simple_inductive, mod_name="model")
     conf_model = conformal_model(model; method=method, coverage=cov)
     mach = machine(conf_model, X, y)
     @info "Begin training $mod_name."
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index f2f84e44..846b3f49 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -16,7 +16,7 @@ using ECCCo
 using Flux
 using JointEnergyModels
 using LazyArtifacts
-using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table
+using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table, machine, fit!
 using MLJEnsembles
 using MLJFlux
 using Serialization
-- 
GitLab