From d6608211b15a69235380f768083485da2cbd441e Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Tue, 22 Aug 2023 14:51:54 +0200 Subject: [PATCH] streamlining mostly done --- experiments/post_processing.jl | 119 +++++++++++++++++++++------------ experiments/setup.jl | 1 + 2 files changed, 77 insertions(+), 43 deletions(-) diff --git a/experiments/post_processing.jl b/experiments/post_processing.jl index 36c4b60f..58d17b82 100644 --- a/experiments/post_processing.jl +++ b/experiments/post_processing.jl @@ -1,63 +1,56 @@ """ - meta_model_performance(outcome::ExperimentOutcome; measure=MODEL_MEASURES) + meta(exp::Experiment) -Compute and save the model performance for the models in `outcome.model_dict`. +Extract and save meta data about the experiment. """ -function meta_model_performance(outcome::ExperimentOutcome; measure=MODEL_MEASURES) +function meta(outcome::ExperimentOutcome) - exp = outcome.exp - model_dict = outcome.model_dict + meta_model(outcome) + meta_model_performance(outcome) + meta_generators(outcome) - # Model performance: - model_performance = DataFrame() - for (mod_name, model) in model_dict - # Test performance: - _perf = CounterfactualExplanations.Models.model_evaluation(model, exp.test_data, measure=collect(values(measure))) - _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) - _perf.mod_name .= mod_name - _perf.dataname .= dataname - model_performance = vcat(model_performance, _perf) - end - Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_model_performance.jls"), model_performance) - CSV.write(joinpath(exp.output_path, "$(exp.save_name)_model_performance.csv"), model_performance) - @info "Model performance:" - println(model_performance) - return model_performance end """ - meta_data(exp::Experiment) + meta_model(outcome::ExperimentOutcome) -Extract and save meta data about the experiment. +Extract and save meta data about the data and models in `outcome.model_dict`. 
""" -function meta_data(outcome::ExperimentOutcome) +function meta_model(outcome::ExperimentOutcome) - # Data params: - _, _, n_obs, default_save_name, batch_size, sampler = prepare_data( - exp.counterfactual_data; - ð’Ÿx=exp.ð’Ÿx, - sampling_batch_size=exp.sampling_batch_size - ) - save_name = isnothing(save_name) ? default_save_name : save_name + # Unpack: + exp = outcome.exp + n_obs, batch_size = meta_data(exp) params = DataFrame( Dict( :n_obs => Int.(round(n_obs / 10) * 10), - :epochs => epochs, :batch_size => batch_size, - :n_hidden => n_hidden, - :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1, - :activation => string(activation), - :n_ens => n_ens, - :lambda => string(α[3]), - :jem_sampling_steps => jem.sampling_steps, - :sgld_batch_size => sampler.batch_size, - :dataname => dataname, + :dataname => exp.dataname, + :sgld_batch_size => exp.sampling_batch_size, + # :epochs => exp.epochs, + # :n_hidden => n_hidden, + # :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1, + # :activation => string(activation), + # :n_ens => n_ens, + # :lambda => string(α[3]), + # :jem_sampling_steps => jem.sampling_steps, ) ) - if !isnothing(save_path) - CSV.write(joinpath(save_path, "$(save_name)_model_params.csv"), params) - end + + save_path = joinpath(exp.params_path, "$(exp.save_name)_model_params.csv") + @info "Saving model parameters to $(save_path)." + CSV.write(save_path, params) + + return params + +end + +function meta_generators(outcome::ExperimentOutcome) + + # Unpack: + exp = outcome.exp + generator_dict = outcome.generator_dict # Output: opt = first(values(generator_dict)).opt @@ -68,6 +61,46 @@ function meta_data(outcome::ExperimentOutcome) :dataname => dataname, ) ) - CSV.write(joinpath(params_path, "$(save_name)_generator_params.csv"), generator_params) + save_path = joinpath(exp.params_path, "$(exp.save_name)_generator_params.csv") + @info "Saving generator parameters to $(save_path)." 
+ CSV.write(save_path, generator_params) + + return generator_params +end + +""" + meta_model_performance(outcome::ExperimentOutcome; measures::Union{Nothing, Dict}=nothing) + +Compute and save the model performance for the models in `outcome.model_dict`. +""" +function meta_model_performance(outcome::ExperimentOutcome; measures::Union{Nothing, Dict}=nothing) + + # Unpack: + exp = outcome.exp + measures = isnothing(measures) ? exp.model_measures : measures + model_dict = outcome.model_dict + + # Model performance: + model_performance = DataFrame() + for (mod_name, model) in model_dict + # Test performance: + _perf = CounterfactualExplanations.Models.model_evaluation(model, exp.test_data, measure=collect(values(measures))) + _perf = DataFrame([[p] for p in _perf], collect(keys(measures))) + _perf.mod_name .= mod_name + _perf.dataname .= exp.dataname + model_performance = vcat(model_performance, _perf) + end + + @info "Model performance:" + println(model_performance) + + save_path = joinpath(exp.params_path, "$(exp.save_name)_model_performance.jls") + @info "Saving model performance to $(save_path)." + Serialization.serialize(save_path, model_performance) + save_path = joinpath(exp.params_path, "$(exp.save_name)_model_performance.csv") + @info "Saving model performance to $(save_path)." + CSV.write(save_path, model_performance) + + return model_performance end \ No newline at end of file diff --git a/experiments/setup.jl b/experiments/setup.jl index 93834ab8..c8b12cc0 100644 --- a/experiments/setup.jl +++ b/experiments/setup.jl @@ -52,6 +52,7 @@ Base.@kwdef struct Experiment generators::Union{Nothing,Dict} = nothing n_individuals::Int = 50 ce_measures::AbstractArray = CE_MEASURES + model_measures::Dict = MODEL_MEASURES end "A container to hold the results of an experiment." -- GitLab