Skip to content
Snippets Groups Projects
Commit d6608211 authored by Pat Alt's avatar Pat Alt
Browse files

streamlining mostly done

parent 7a2a2853
No related branches found
No related tags found
1 merge request!68Post rebuttal
""" """
meta_model_performance(outcome::ExperimentOutcome; measure=MODEL_MEASURES) meta(exp::Experiment)
Compute and save the model performance for the models in `outcome.model_dict`. Extract and save meta data about the experiment.
""" """
function meta_model_performance(outcome::ExperimentOutcome; measure=MODEL_MEASURES) function meta(outcome::ExperimentOutcome)
exp = outcome.exp meta_model(outcome)
model_dict = outcome.model_dict meta_model_performance(outcome)
meta_generators(outcome)
# Model performance:
model_performance = DataFrame()
for (mod_name, model) in model_dict
# Test performance:
_perf = CounterfactualExplanations.Models.model_evaluation(model, exp.test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
_perf.dataname .= dataname
model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_model_performance.jls"), model_performance)
CSV.write(joinpath(exp.output_path, "$(exp.save_name)_model_performance.csv"), model_performance)
@info "Model performance:"
println(model_performance)
return model_performance
end end
""" """
meta_data(exp::Experiment) meta_model(outcome::ExperimentOutcome)
Extract and save meta data about the experiment. Extract and save meta data about the data and models in `outcome.model_dict`.
""" """
function meta_data(outcome::ExperimentOutcome) function meta_model(outcome::ExperimentOutcome)
# Data params: # Unpack:
_, _, n_obs, default_save_name, batch_size, sampler = prepare_data( exp = outcome.exp
exp.counterfactual_data; n_obs, batch_size = meta_data(exp)
𝒟x=exp.𝒟x,
sampling_batch_size=exp.sampling_batch_size
)
save_name = isnothing(save_name) ? default_save_name : save_name
params = DataFrame( params = DataFrame(
Dict( Dict(
:n_obs => Int.(round(n_obs / 10) * 10), :n_obs => Int.(round(n_obs / 10) * 10),
:epochs => epochs,
:batch_size => batch_size, :batch_size => batch_size,
:n_hidden => n_hidden, :dataname => exp.dataname,
:n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1, :sgld_batch_size => exp.sampling_batch_size,
:activation => string(activation), # :epochs => exp.epochs,
:n_ens => n_ens, # :n_hidden => n_hidden,
:lambda => string(α[3]), # :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1,
:jem_sampling_steps => jem.sampling_steps, # :activation => string(activation),
:sgld_batch_size => sampler.batch_size, # :n_ens => n_ens,
:dataname => dataname, # :lambda => string(α[3]),
# :jem_sampling_steps => jem.sampling_steps,
) )
) )
if !isnothing(save_path)
CSV.write(joinpath(save_path, "$(save_name)_model_params.csv"), params) save_path = joinpath(exp.params_path, "$(exp.save_name)_model_params.csv")
end @info "Saving model parameters to $(save_path)."
CSV.write(save_path, params)
return params
end
function meta_generators(outcome::ExperimentOutcome)
# Unpack:
exp = outcome.exp
generator_dict = outcome.generator_dict
# Output: # Output:
opt = first(values(generator_dict)).opt opt = first(values(generator_dict)).opt
...@@ -68,6 +61,46 @@ function meta_data(outcome::ExperimentOutcome) ...@@ -68,6 +61,46 @@ function meta_data(outcome::ExperimentOutcome)
:dataname => dataname, :dataname => dataname,
) )
) )
CSV.write(joinpath(params_path, "$(save_name)_generator_params.csv"), generator_params)
save_path = joinpath(exp.params_path, "$(exp.save_name)_generator_params.csv")
@info "Saving generator parameters to $(save_path)."
CSV.write(save_path, generator_params)
return generator_params
end
"""
meta_model_performance(outcome::ExperimentOutcome; measures=MODEL_MEASURES)
Compute and save the model performance for the models in `outcome.model_dict`.
"""
function meta_model_performance(outcome::ExperimentOutcome; measures::Union{Nothing, Dict}=nothing)
# Unpack:
exp = outcome.exp
measures = isnothing(measures) ? exp.model_measures : measures
model_dict = outcome.model_dict
# Model performance:
model_performance = DataFrame()
for (mod_name, model) in model_dict
# Test performance:
_perf = CounterfactualExplanations.Models.model_evaluation(model, exp.test_data, measure=collect(values(measures)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measures)))
_perf.mod_name .= mod_name
_perf.dataname .= dataname
model_performance = vcat(model_performance, _perf)
end
@info "Model performance:"
println(model_performance)
save_path = joinpath(exp.params_path, "$(exp.save_name)_model_performance.jls")
@info "Saving model performance to $(save_path)."
Serialization.serialize(save_path, model_performance)
save_path = joinpath(exp.params_path, "$(exp.save_name)_model_performance.csv")
@info "Saving model performance to $(save_path)."
CSV.write(save_path, model_performance)
return model_performance
end end
\ No newline at end of file
...@@ -52,6 +52,7 @@ Base.@kwdef struct Experiment ...@@ -52,6 +52,7 @@ Base.@kwdef struct Experiment
generators::Union{Nothing,Dict} = nothing generators::Union{Nothing,Dict} = nothing
n_individuals::Int = 50 n_individuals::Int = 50
ce_measures::AbstractArray = CE_MEASURES ce_measures::AbstractArray = CE_MEASURES
model_measures::Dict = MODEL_MEASURES
end end
"A container to hold the results of an experiment." "A container to hold the results of an experiment."
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment