diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 41559715dc002eb3a06bf098bb4249a03c3ee619..1e1845ee3ade50980a04d88defdf2add71c7c844 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -6,7 +6,7 @@ Base.@kwdef struct Experiment
     save_name::String = replace(lowercase(dataname), " " => "_")
     output_path::String = DEFAULT_OUTPUT_PATH
     params_path::String = joinpath(output_path, "params")
-    use_pretrained::Bool = true
+    use_pretrained::Bool = !RETRAIN
     models::Union{Nothing,Dict} = nothing
     builder::Union{Nothing,MLJFlux.Builder} = nothing
     𝒟x::Distribution = Normal()
@@ -74,5 +74,6 @@ function run_experiment!(counterfactual_data::CounterfactualData, test_data::Cou
         test_data=test_data,
         kwargs...
     )
+    println(exp.use_pretrained)
     return run_experiment!(exp)
 end
\ No newline at end of file
diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl
index f884bd206b59fffefde9ac83349408e862761a83..a21444d49cc89d37b318bdaf49eef6113fc610fe 100644
--- a/experiments/linearly_separable.jl
+++ b/experiments/linearly_separable.jl
@@ -3,4 +3,4 @@
 counterfactual_data, test_data = train_test_split(
     load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0)); test_size=TEST_SIZE
 )
-run_experiment!(counterfactual_data, test_data; dataname="Linearly Separable")
\ No newline at end of file
+run_experiment!(counterfactual_data, test_data; dataname="Linearly Separable", use_pretrained=false)
\ No newline at end of file
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index f9d645bf88295b1a6c51ba5fb2fb42f794d5f05f..7ad1c36d3559d2e6bb57c39b22c2c550c558d062 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -21,7 +21,7 @@ function prepare_models(exp::Experiment)
             )
         end
         @info "Training models."
-        model_dict = train_models(models, X, labels; coverage=exp.coverage)
+        model_dict = train_models(models, X, labels; cov=exp.coverage)
     else
         @info "Loading pre-trained models."
         model_dict = Serialization.deserialize(joinpath(pretrained_path(), "results/$(exp.save_name)_models.jls"))
diff --git a/experiments/models/train_models.jl b/experiments/models/train_models.jl
index 429dbe64a8bea57cbfbd4ec5f56a7e0d2a2b3df6..863584a48bd8570fd672666290f912f9abac958c 100644
--- a/experiments/models/train_models.jl
+++ b/experiments/models/train_models.jl
@@ -20,7 +20,7 @@ end
 
 Trains a model and returns a `ConformalModel` object.
 """
-function _train(model, X, y; cov=coverage, method=:simple_inductive, mod_name="model")
+function _train(model, X, y; cov, method=:simple_inductive, mod_name="model")
     conf_model = conformal_model(model; method=method, coverage=cov)
     mach = machine(conf_model, X, y)
     @info "Begin training $mod_name."
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index f2f84e444ad21bb588beb1f798b64d864d5ee171..846b3f496ed543c4ddc8b71f3ba9da995e8d8321 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -16,7 +16,7 @@ using ECCCo
 using Flux
 using JointEnergyModels
 using LazyArtifacts
-using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table
+using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table, machine, fit!
 using MLJEnsembles
 using MLJFlux
 using Serialization