# experiment.jl
"""
    Experiment(; counterfactual_data, test_data, kwargs...)

Configuration for a single experiment run, constructed with keyword arguments (via
`Base.@kwdef`). Only `counterfactual_data` and `test_data` are required; every other
field has a default.

Several defaults are derived from fields declared earlier in the struct:
- `save_name` from `dataname`
- `params_path` from `output_path`
- `𝒟x` from `counterfactual_data`
- `builder` from `n_hidden`/`n_layers`/`activation`
- `Λ_Δ` from `Λ`

Overriding an earlier field therefore also changes the derived default, unless the
derived field is passed explicitly as well.
"""
Base.@kwdef struct Experiment
    # Data: these two fields are the only required keywords.
    counterfactual_data::CounterfactualData
    test_data::CounterfactualData
    # Naming and output locations. `save_name` defaults to `dataname` lowercased with
    # spaces replaced by "_" and is used as the prefix of serialized result files
    # (e.g. "<save_name>_outcome.jls"); `params_path` defaults to "<output_path>/params".
    dataname::String = "dataset"
    save_name::String = replace(lowercase(dataname), " " => "_")
    output_path::String = DEFAULT_OUTPUT_PATH
    params_path::String = joinpath(output_path, "params")
    # Defaults to `!RETRAIN` (a global defined elsewhere).
    # NOTE(review): presumably selects loading pre-trained models over retraining — confirm.
    use_pretrained::Bool = !RETRAIN
    models::Union{Nothing,Dict} = nothing
    additional_models::Union{Nothing,Dict} = nothing
    # Sampling space; the default is derived from `counterfactual_data` by
    # `ECCCo.prior_sampling_space`.
    𝒟x::Distribution = ECCCo.prior_sampling_space(counterfactual_data)
    # Sampling / training sizes (semantics are defined by the code that consumes them).
    sampling_batch_size::Int = 50
    sampling_steps::Int = 50
    min_batch_size::Int = 128
    epochs::Int = 100
    # Network architecture. The default `builder` is constructed from these three
    # fields, so they have no effect when a custom `builder` is passed.
    n_hidden::Int = 32
    n_layers::Int = 3
    activation::Function = Flux.relu
    builder::Union{Nothing,MLJFlux.Builder} =
        default_builder(n_hidden = n_hidden, n_layers = n_layers, activation = activation)
    # NOTE(review): presumably per-term penalty weights for a generator — confirm.
    α::AbstractArray = [1.0, 1.0, 1e-1]
    # Ensembling settings and target coverage (default `DEFAULT_COVERAGE`, defined elsewhere).
    n_ens::Int = 5
    use_ensembling::Bool = true
    coverage::Float64 = DEFAULT_COVERAGE
    # Generators and evaluation settings (defaults are globals defined elsewhere).
    generators::Union{Nothing,Dict} = nothing
    n_individuals::Int = N_IND
    n_runs::Int = N_RUNS
    ce_measures::AbstractArray = CE_MEASURES
    model_measures::Dict = MODEL_MEASURES
    use_class_loss::Bool = false
    use_variants::Bool = true
    # `Λ_Δ` defaults to the *same* array object as `Λ` (aliased, not copied), so an
    # in-place mutation of one is visible through the other.
    Λ::AbstractArray = [0.25, 0.75, 0.75]
    Λ_Δ::AbstractArray = Λ
    # Optimiser: plain gradient descent with step size 0.01 by default.
    opt::Flux.Optimise.AbstractOptimiser = Flux.Optimise.Descent(0.01)
    # Parallelisation. With an MPI-based parallelizer, `run_experiment` and
    # `train_models!` read `parallelizer.comm` for rank checks and barriers.
    parallelizer::Union{Nothing,AbstractParallelizer} = PLZ
    nsamples::Union{Nothing,Int} = nothing
    nmin::Union{Nothing,Int} = nothing
    # Output finaliser and loss: softmax and cross-entropy by default.
    finaliser::Function = Flux.softmax
    loss::Function = Flux.Losses.crossentropy
    train_parallel::Bool = false
    reg_strength::Real = 0.1
    decay::Tuple = (0.1, 5)
    niter_eccco::Union{Nothing,Int} = nothing
    # Model tuning and bookkeeping flags.
    model_tuning_params::NamedTuple = DEFAULT_MODEL_TUNING_SMALL
    use_tuned::Bool = true
    store_ce::Bool = STORE_CE
    dim_reduction::Bool = false
end

"""
    ExperimentOutcome(exper, model_dict, generator_dict, bmk)

Mutable record of everything an experiment has produced. Apart from `exper`, each
field may be `nothing` until the corresponding stage has filled it in:
- `model_dict` is set by training.
- `generator_dict` and `bmk` are set by benchmarking.
"""
mutable struct ExperimentOutcome
    "The `Experiment` configuration this outcome belongs to."
    exper::Experiment
    "Models produced by training, or `nothing` before training has run."
    model_dict::Union{Dict,Nothing}
    "Generators used in benchmarking, or `nothing` before benchmarking has run."
    generator_dict::Union{Dict,Nothing}
    "Benchmark results, or `nothing` before benchmarking has run."
    bmk::Union{Benchmark,Nothing}
end

"""
    train_models!(outcome::ExperimentOutcome, exper::Experiment; save_models=true, save_meta=false)

Build the models described by `exper` with `prepare_models` and store them in
`outcome.model_dict`. Afterwards, `meta_model_performance` is evaluated on the root
process only.

# Keywords
- `save_models`: forwarded to `prepare_models`.
- `save_meta`: forwarded as `save_output` to `meta_model_performance`.

# Returns
The result of `meta_model_performance` on the root process, `nothing` on other MPI ranks.
"""
function train_models!(
    outcome::ExperimentOutcome,
    exper::Experiment;
    save_models::Bool = true,
    save_meta::Bool = false,
)
    outcome.model_dict = prepare_models(exper; save_models = save_models)

    # Non-root MPI ranks skip the performance evaluation. `MPI.Comm_rank` is only
    # queried when the experiment is actually multi-processed.
    on_worker_rank =
        is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0
    on_worker_rank && return nothing

    return meta_model_performance(outcome; save_output = save_meta)
end

"""
    benchmark!(outcome::ExperimentOutcome, exper::Experiment)

Run `run_benchmark` on the models already stored in `outcome.model_dict`. The
generators used are stored in `outcome.generator_dict` and the benchmark result in
`outcome.bmk`.

# Returns
The benchmark result that was stored in `outcome.bmk`.
"""
function benchmark!(outcome::ExperimentOutcome, exper::Experiment)
    (benchmark_result, generators_used) = run_benchmark(exper, outcome.model_dict)
    outcome.generator_dict = generators_used
    outcome.bmk = benchmark_result
    return benchmark_result
end

"""
    run_experiment(exper::Experiment; save_output::Bool=true, only_models::Bool=ONLY_MODELS)

Run the experiment specified by `exper`: train the models and, unless `only_models`
is set, benchmark them.

# Keywords
- `save_output`: on the root process, create `exper.output_path` and
  `exper.params_path` (including missing parent directories). Then serialize the
  outcome and benchmark there as "<save_name>_outcome.jls" and "<save_name>_bmk.jls",
  and save meta-data. When `false`, nothing is written by this function.
- `only_models`: stop after training and return the outcome without benchmarking.

# Returns
An `ExperimentOutcome`. If the global `TUNE_MODEL` flag is set, the machine returned
by `tune_mlp(exper)` is returned instead and nothing else is run.
"""
function run_experiment(
    exper::Experiment;
    save_output::Bool = true,
    only_models::Bool = ONLY_MODELS,
)

    # Only the root process touches the file system. A single (non-MPI) process is
    # always root; `MPI.Comm_rank` is only queried when running multi-processed.
    is_root =
        !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)

    # Setup: `mkpath` (unlike `mkdir`) also creates missing parent directories and is
    # a no-op if the directory already exists.
    if save_output && is_root
        @info "All results will be saved to $(exper.output_path)."
        mkpath(exper.output_path)
        @info "All parameter choices will be saved to $(exper.params_path)."
        mkpath(exper.params_path)
    end

    # Container that the stages below fill in:
    outcome = ExperimentOutcome(exper, nothing, nothing, nothing)

    # Model tuning short-circuits the whole experiment:
    if TUNE_MODEL
        mach = tune_mlp(exper)
        return mach
    end

    # Model training. When stopping after training, meta-data is saved here, but only
    # if output is requested. This matches the full run, where `all_meta` below is
    # likewise gated on `save_output`.
    if only_models
        train_models!(outcome, exper; save_models = save_output, save_meta = save_output)
        return outcome
    end
    train_models!(outcome, exper; save_models = save_output)

    # Benchmark, then synchronise all ranks before results are written:
    benchmark!(outcome, exper)
    if is_multi_processed(exper)
        MPI.Barrier(exper.parallelizer.comm)
    end

    # Save data (root process only):
    if save_output && is_root
        Serialization.serialize(
            joinpath(exper.output_path, "$(exper.save_name)_outcome.jls"),
            outcome,
        )
        Serialization.serialize(
            joinpath(exper.output_path, "$(exper.save_name)_bmk.jls"),
            outcome.bmk,
        )
        all_meta(outcome; save_output = true)
    end

    # Final barrier so that no rank returns before the root has finished writing:
    if is_multi_processed(exper)
        MPI.Barrier(exper.parallelizer.comm)
    end

    return outcome

end

"""
    run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; save_output=true, only_models=ONLY_MODELS, kwargs...)

Convenience method: build an `Experiment` from `counterfactual_data`, `test_data` and
any further `Experiment` fields given as keyword arguments, then run it.

# Keywords
- `save_output`, `only_models`: forwarded to `run_experiment(::Experiment)`.
- `kwargs...`: passed to the `Experiment` keyword constructor. Unknown keywords
  raise an error there.
"""
function run_experiment(
    counterfactual_data::CounterfactualData,
    test_data::CounterfactualData;
    save_output::Bool = true,
    only_models::Bool = ONLY_MODELS,
    kwargs...,
)
    # `only_models` is not an `Experiment` field. It must be intercepted here, because
    # the `@kwdef` constructor would reject it if it were left in `kwargs`.
    exper = Experiment(;
        counterfactual_data = counterfactual_data,
        test_data = test_data,
        kwargs...,
    )
    return run_experiment(exper; save_output = save_output, only_models = only_models)
end

# Pre-trained models:
"""
    pretrained_path(exper::Experiment)

Return the directory from which pre-trained models for `exper` should be loaded.
Candidates are checked in this order:

1. `DEFAULT_OUTPUT_PATH`, if it contains "<save_name>_models.jls";
2. the relative directory "models", if it contains that file;
3. otherwise the artifact identified by `ARTIFACT_HASH` is downloaded and
   "<LATEST_ARTIFACT_PATH>/results" is returned.

NOTE(review): the local lookup uses `DEFAULT_OUTPUT_PATH`, not `exper.output_path` —
confirm this is intended.
"""
function pretrained_path(exper::Experiment)
    model_file = "$(exper.save_name)_models.jls"
    if isfile(joinpath(DEFAULT_OUTPUT_PATH, model_file))
        @info "Found local pre-trained models in $(DEFAULT_OUTPUT_PATH) and using those."
        return DEFAULT_OUTPUT_PATH
    elseif isfile(joinpath("models", model_file))
        @info "Found local pre-trained models in models/ and using those."
        return "models"
    else
        @info "Using artifacts. Models were pre-trained on `julia-$(LATEST_VERSION)` and may not work on other versions."
        Pkg.Artifacts.download_artifact(ARTIFACT_HASH, ARTIFACT_TOML)
        return joinpath(LATEST_ARTIFACT_PATH, "results")
    end
end