From 8e006ea09873bb329c99ec712ecbf7c9ae75062f Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Tue, 22 Aug 2023 17:58:50 +0200
Subject: [PATCH] basic flow seems to be working now

---
 experiments/experiment.jl          | 3 ++-
 experiments/linearly_separable.jl  | 2 +-
 experiments/models/models.jl       | 2 +-
 experiments/models/train_models.jl | 2 +-
 experiments/setup_env.jl           | 2 +-
 5 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 41559715..1e1845ee 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -6,7 +6,7 @@ Base.@kwdef struct Experiment
     save_name::String = replace(lowercase(dataname), " " => "_")
     output_path::String = DEFAULT_OUTPUT_PATH
     params_path::String = joinpath(output_path, "params")
-    use_pretrained::Bool = true
+    use_pretrained::Bool = !RETRAIN
     models::Union{Nothing,Dict} = nothing
     builder::Union{Nothing,MLJFlux.Builder} = nothing
     𝒟x::Distribution = Normal()
@@ -74,5 +74,6 @@ function run_experiment!(counterfactual_data::CounterfactualData, test_data::Cou
         test_data=test_data,
         kwargs...
     )
+    println(exp.use_pretrained)
     return run_experiment!(exp)
 end
\ No newline at end of file
diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl
index f884bd20..a21444d4 100644
--- a/experiments/linearly_separable.jl
+++ b/experiments/linearly_separable.jl
@@ -3,4 +3,4 @@ counterfactual_data, test_data = train_test_split(
     load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0));
     test_size=TEST_SIZE
 )
-run_experiment!(counterfactual_data, test_data; dataname="Linearly Separable")
\ No newline at end of file
+run_experiment!(counterfactual_data, test_data; dataname="Linearly Separable", use_pretrained=false)
\ No newline at end of file
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index f9d645bf..7ad1c36d 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -21,7 +21,7 @@ function prepare_models(exp::Experiment)
             )
         end
        @info "Training models."
-        model_dict = train_models(models, X, labels; coverage=exp.coverage)
+        model_dict = train_models(models, X, labels; cov=exp.coverage)
     else
        @info "Loading pre-trained models."
        model_dict = Serialization.deserialize(joinpath(pretrained_path(), "results/$(exp.save_name)_models.jls"))
diff --git a/experiments/models/train_models.jl b/experiments/models/train_models.jl
index 429dbe64..863584a4 100644
--- a/experiments/models/train_models.jl
+++ b/experiments/models/train_models.jl
@@ -20,7 +20,7 @@ end
 
 Trains a model and returns a `ConformalModel` object.
 """
-function _train(model, X, y; cov=coverage, method=:simple_inductive, mod_name="model")
+function _train(model, X, y; cov, method=:simple_inductive, mod_name="model")
     conf_model = conformal_model(model; method=method, coverage=cov)
     mach = machine(conf_model, X, y)
     @info "Begin training $mod_name."
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index f2f84e44..846b3f49 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -16,7 +16,7 @@ using ECCCo
 using Flux
 using JointEnergyModels
 using LazyArtifacts
-using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table
+using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table, machine, fit!
 using MLJEnsembles
 using MLJFlux
 using Serialization
-- 
GitLab