From d312a6bfe5e753a17f276b9a7c8a5fc1cc9c7da5 Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Sun, 17 Sep 2023 20:44:09 +0200 Subject: [PATCH] some stuff --- experiments/Manifest.toml | 6 ++++-- experiments/experiment.jl | 8 ++++---- experiments/grid_search.jl | 4 +++- experiments/jobscripts/tuning/generators/synthetic.sh | 2 +- experiments/models/models.jl | 4 ++-- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml index 244c5a2f..a82fd00b 100644 --- a/experiments/Manifest.toml +++ b/experiments/Manifest.toml @@ -444,9 +444,11 @@ version = "0.6.2" [[deps.CounterfactualExplanations]] deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"] -git-tree-sha1 = "04acac995c56d1609277848e3bdd035e89b712c2" +git-tree-sha1 = "108bf1c3a0cefdcc53bda5482ea1aa6d8ee86382" +repo-rev = "main" +repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.21" +version = "0.1.23" [deps.CounterfactualExplanations.extensions] MPIExt = "MPI" diff --git a/experiments/experiment.jl b/experiments/experiment.jl index baf73b49..74755b7b 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -54,8 +54,8 @@ end Train the models specified by `exper` and store them in `outcome`. """ -function train_models!(outcome::ExperimentOutcome, exper::Experiment; save_meta::Bool=false) - model_dict = prepare_models(exper) +function train_models!(outcome::ExperimentOutcome, exper::Experiment; save_models::Bool=true, save_meta::Bool=false) + model_dict = prepare_models(exper; save_models=save_models) outcome.model_dict = model_dict if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) meta_model_performance(outcome; save_output=save_meta) @@ -97,10 +97,10 @@ function run_experiment(exper::Experiment; save_output::Bool=true, only_models:: # Model training: if only_models - train_models!(outcome, exper; save_meta=true) + train_models!(outcome, exper; save_models=save_output, save_meta=true) return outcome else - train_models!(outcome, exper) + train_models!(outcome, exper; save_models=save_output) end # Benchmark: diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index 090acada..e65f8ae8 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -44,5 +44,7 @@ function grid_search( end # Save: - Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes) + if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes) + end end \ No newline at end of file diff --git a/experiments/jobscripts/tuning/generators/synthetic.sh b/experiments/jobscripts/tuning/generators/synthetic.sh index 83744419..4a4e857d 100644 --- a/experiments/jobscripts/tuning/generators/synthetic.sh +++ b/experiments/jobscripts/tuning/generators/synthetic.sh @@ -1,7 +1,7 @@ #!/bin/bash #SBATCH --job-name="Grid-search Synthetic (ECCCo)" -#SBATCH --time=06:00:00 +#SBATCH --time=02:00:00 #SBATCH --ntasks=1000 #SBATCH --cpus-per-task=1 #SBATCH --partition=compute diff --git a/experiments/models/models.jl b/experiments/models/models.jl index 0ce02068..1b0093b2 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -2,7 +2,7 @@ include("additional_models.jl") include("default_models.jl") include("train_models.jl") -function prepare_models(exper::Experiment) +function prepare_models(exper::Experiment; save_models::Bool=true) # Unpack data: X, labels, sampler = prepare_data(exper::Experiment) @@ -75,7 +75,7 @@ function prepare_models(exper::Experiment) end # Save models: - if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + if save_models && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) @info "Saving models to $(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))." Serialization.serialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"), model_dict) end -- GitLab