"Sets up the experiment." Base.@kwdef struct Experiment counterfactual_data::CounterfactualData test_data::CounterfactualData dataname::String = "dataset" save_name::String = replace(lowercase(dataname), " " => "_") output_path::String = DEFAULT_OUTPUT_PATH params_path::String = joinpath(output_path, "params") use_pretrained::Bool = !RETRAIN models::Union{Nothing,Dict} = nothing additional_models::Union{Nothing,Dict} = nothing 𝒟x::Distribution = Normal() sampling_batch_size::Int = 50 sampling_steps::Int = 50 min_batch_size::Int = 128 epochs::Int = 100 n_hidden::Int = 32 n_layers::Int = 3 activation::Function = Flux.relu builder::Union{Nothing,MLJFlux.Builder} = default_builder(n_hidden = n_hidden, n_layers = n_layers, activation = activation) α::AbstractArray = [1.0, 1.0, 1e-1] n_ens::Int = 5 use_ensembling::Bool = true coverage::Float64 = DEFAULT_COVERAGE generators::Union{Nothing,Dict} = nothing n_individuals::Int = N_IND n_runs::Int = N_RUNS ce_measures::AbstractArray = CE_MEASURES model_measures::Dict = MODEL_MEASURES use_class_loss::Bool = false use_variants::Bool = true Λ::AbstractArray = [0.25, 0.75, 0.75] Λ_Δ::AbstractArray = Λ opt::Flux.Optimise.AbstractOptimiser = Flux.Optimise.Descent(0.01) parallelizer::Union{Nothing,AbstractParallelizer} = PLZ nsamples::Union{Nothing,Int} = nothing nmin::Union{Nothing,Int} = nothing finaliser::Function = Flux.softmax loss::Function = Flux.Losses.crossentropy train_parallel::Bool = false reg_strength::Real = 0.1 decay::Tuple = (0.1, 5) niter_eccco::Union{Nothing,Int} = nothing model_tuning_params::NamedTuple = DEFAULT_MODEL_TUNING_SMALL use_tuned::Bool = true store_ce::Bool = STORE_CE dim_reduction::Bool = false end "A container to hold the results of an experiment." mutable struct ExperimentOutcome exper::Experiment model_dict::Union{Nothing,Dict} generator_dict::Union{Nothing,Dict} bmk::Union{Nothing,Benchmark} end """ train_models!(outcome::ExperimentOutcome, exper::Experiment) Train the models specified by `exper` and store them in `outcome`. """ function train_models!( outcome::ExperimentOutcome, exper::Experiment; save_models::Bool = true, save_meta::Bool = false, ) model_dict = prepare_models(exper; save_models = save_models) outcome.model_dict = model_dict if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) meta_model_performance(outcome; save_output = save_meta) end end """ benchmark!(outcome::ExperimentOutcome, exper::Experiment) Benchmark the models specified by `exper` and store the results in `outcome`. """ function benchmark!(outcome::ExperimentOutcome, exper::Experiment) bmk, generator_dict = run_benchmark(exper, outcome.model_dict) outcome.generator_dict = generator_dict outcome.bmk = bmk end """ run_experiment(exper::Experiment) Run the experiment specified by `exper`. """ function run_experiment( exper::Experiment; save_output::Bool = true, only_models::Bool = ONLY_MODELS, ) # Setup if save_output && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) @info "All results will be saved to $(exper.output_path)." isdir(exper.output_path) || mkdir(exper.output_path) @info "All parameter choices will be saved to $(exper.params_path)." 
isdir(exper.params_path) || mkdir(exper.params_path) end # Run the experiment: outcome = ExperimentOutcome(exper, nothing, nothing, nothing) # Model tuning: if TUNE_MODEL mach = tune_mlp(exper) return mach end # Model training: if only_models train_models!(outcome, exper; save_models = save_output, save_meta = true) return outcome else train_models!(outcome, exper; save_models = save_output) end # Benchmark: benchmark!(outcome, exper) if is_multi_processed(exper) MPI.Barrier(exper.parallelizer.comm) end # Save data: if save_output && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) Serialization.serialize( joinpath(exper.output_path, "$(exper.save_name)_outcome.jls"), outcome, ) Serialization.serialize( joinpath(exper.output_path, "$(exper.save_name)_bmk.jls"), outcome.bmk, ) all_meta(outcome; save_output = true) end # Final barrier: if is_multi_processed(exper) MPI.Barrier(exper.parallelizer.comm) end return outcome end """ run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...) Overload the `run_experiment` function to allow for passing in `CounterfactualData` objects and other keyword arguments. """ function run_experiment( counterfactual_data::CounterfactualData, test_data::CounterfactualData; save_output::Bool = true, kwargs..., ) # Parameters: exper = Experiment(; counterfactual_data = counterfactual_data, test_data = test_data, kwargs..., ) return run_experiment(exper; save_output = save_output) end # Pre-trained models: function pretrained_path(exper::Experiment) if isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls")) @info "Found local pre-trained models in $(DEFAULT_OUTPUT_PATH) and using those." return DEFAULT_OUTPUT_PATH elseif isfile("models/$(exper.save_name)_models.jls") @info "Found local pre-trained models in models/ and using those." return "models" else @info "Using artifacts. Models were pre-trained on `julia-$(LATEST_VERSION)` and may not work on other versions." Pkg.Artifacts.download_artifact(ARTIFACT_HASH, ARTIFACT_TOML) return joinpath(LATEST_ARTIFACT_PATH, "results") end end
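
# Example usage (sketch, kept as a comment so that including this file has no side effects):
# the data-loading helpers below are illustrative assumptions and not defined in this file.
# A typical call builds the two `CounterfactualData` splits and forwards any `Experiment`
# keyword arguments through the `run_experiment` overload above.
#
#   counterfactual_data, test_data = load_my_splits()   # hypothetical helper returning two CounterfactualData objects
#   outcome = run_experiment(
#       counterfactual_data, test_data;
#       dataname = "My Dataset",        # also determines `save_name` via the @kwdef default
#       epochs = 50,
#       use_pretrained = false,
#       save_output = true,
#   )
#   outcome.bmk                          # benchmark results, populated by `benchmark!`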