Skip to content
Snippets Groups Projects
Commit f53fe37b authored by Pat Alt's avatar Pat Alt
Browse files

adjusted method to choose best from grid search

parent 5e8fbab2
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -91,26 +91,33 @@ function run_experiment(exper::Experiment; save_output::Bool=true, only_models::
@info "All parameter choices will be saved to $(exper.params_path)."
isdir(exper.params_path) || mkdir(exper.params_path)
end
outcome = ExperimentOutcome(exper, nothing, nothing, nothing)
# Model tuning:
if TUNE_MODEL
mach = tune_mlp(exper)
return mach
end
# Model training:
if only_models
train_models!(outcome, exper; save_models=save_output, save_meta=true)
return outcome
if FROM_GRID_SEARCH
# Just load the best model from the grid search:
outcome = Serialization.deserialize(joinpath(exper.output_path, "grid_search", "$(exper.save_name)_best_eccco_delta.jls"))
else
train_models!(outcome, exper; save_models=save_output)
end
# Benchmark:
benchmark!(outcome, exper)
if is_multi_processed(exper)
MPI.Barrier(exper.parallelizer.comm)
# Run the experiment:
outcome = ExperimentOutcome(exper, nothing, nothing, nothing)
# Model tuning:
if TUNE_MODEL
mach = tune_mlp(exper)
return mach
end
# Model training:
if only_models
train_models!(outcome, exper; save_models=save_output, save_meta=true)
return outcome
else
train_models!(outcome, exper; save_models=save_output)
end
# Benchmark:
benchmark!(outcome, exper)
if is_multi_processed(exper)
MPI.Barrier(exper.parallelizer.comm)
end
end
# Save data:
......
......@@ -78,12 +78,11 @@ const ECCCo_Δ_NAMES = [
Returns the best outcome from grid search results. The best outcome is defined as the one with the lowest average rank across all datasets and variables for the specified generator and measure.
"""
function best_outcome(
function best_rank_outcome(
outcomes::Dict;
generator=ALL_ECCCO_NAMES,
measure=["distance_from_energy_l2", "distance_from_targets_l2"],
model::Union{Nothing,AbstractArray}=nothing,
weights::Union{Nothing,AbstractArray}=nothing
)
weights = isnothing(weights) ? ones(length(measure)) : weights
......@@ -105,30 +104,60 @@ function best_outcome(
return best_outcome
end
best_eccco(outcomes; kwrgs...) = best_outcome(outcomes; generator=ECCCO_NAMES, kwrgs...)
best_rank_eccco(outcomes; kwrgs...) = best_outcome(outcomes; generator=ECCCO_NAMES, kwrgs...)
best_eccco_Δ(outcomes; kwrgs...) = best_outcome(outcomes; generator=ECCCo_Δ_NAMES, kwrgs...)
best_rank_eccco_Δ(outcomes; kwrgs...) = best_outcome(outcomes; generator=ECCCo_Δ_NAMES, kwrgs...)
"""
best_absolute_outcome(outcomes; generator=ECCCO_NAMES, measure="distance_from_energy")
Return the best outcome from grid search results. The best outcome is defined as the one with the lowest average value across all datasets and variables for the specified generator and measure.
"""
function best_absolute_outcome(outcomes::Dict; generator=ECCCO_NAMES, measure::String="distance_from_energy_l2", model::Union{Nothing,AbstractArray}=nothing)
function best_absolute_outcome(
outcomes::Dict;
generator=ECCCO_NAMES,
measure::AbstractArray=["distance_from_targets_l2", "distance_from_energy_l2"],
model::Union{Nothing,AbstractArray}=nothing,
weights::Union{Nothing,AbstractArray}=nothing
)
weights = isnothing(weights) ? ones(length(measure)) : weights
df_weights = DataFrame(variable=measure, weight=weights)
avg_values = []
for (params, outcome) in outcomes
# Compute:
results = summarise_outcome(outcome, measure=[measure], model=model)
# Setup
evaluation = outcome.bmk.evaluation
exper = outcome.exper
generator_dict = outcome.generator_dict
model_dict = outcome.model_dict
# Adjust variables for which higher is better:
higher_is_better = [var ["validity", "redundancy"] for var in results.variable]
results.mean[higher_is_better] .= -results.mean[higher_is_better]
# Compute avergaes:
higher_is_better = [var ["validity", "redundancy"] for var in evaluation.variable]
evaluation.value[higher_is_better] .= -evaluation.value[higher_is_better]
# Normalise to allow for comparison across measures:
evaluation =
groupby(evaluation, [:dataname, :variable]) |>
x -> transform(x, :value => standardize => :value)
# Reconstruct outcome with normalised values:
bmk = CounterfactualExplanations.Evaluation.Benchmark(evaluation)
outcome = ExperimentOutcome(exper, model_dict, generator_dict, bmk)
# Compute:
results = summarise_outcome(outcome, measure=measure, model=model) |>
x -> leftjoin(x, df_weights, on=:variable)
# Compute weighted averages:
_avg_values = subset(results, :generator => ByRow(x -> x generator)) |>
x -> x.mean |>
x -> x.mean .* x.weight |>
x -> (sum(x)/length(x))[1]
# Append:
push!(avg_values, _avg_values)
end
println(avg_values)
best_index = argmin(avg_values)
best_outcome = (
params = collect(keys(outcomes))[best_index],
......
......@@ -195,3 +195,6 @@ DEFAULT_MODEL_TUNING_LARGE = (
"Boolean flag to check if store counterfactual explanations was specified."
STORE_CE = "store_ce" ARGS
"Boolean flag to chech if best outcome from grid search should be used."
FROM_GRID_SEARCH = "from_grid" ARGS
using CounterfactualExplanations.Parallelization: ThreadsParallelizer
using LinearAlgebra: norm
function is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer})
if isnothing(parallelizer) || isa(parallelizer, ThreadsParallelizer)
......@@ -8,4 +9,16 @@ function is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer})
end
end
is_multi_processed(exper::Experiment) = is_multi_processed(exper.parallelizer)
\ No newline at end of file
is_multi_processed(exper::Experiment) = is_multi_processed(exper.parallelizer)
function min_max_scale(x::AbstractArray)
x_norm = (x .- minimum(x)) ./ (maximum(x) - minimum(x))
x_norm = replace(x_norm, NaN => 0.0)
return x_norm
end
function standardize(x::AbstractArray)
x_norm = (x .- sum(x)/length(x)) ./ std(x)
x_norm = replace(x_norm, NaN => 0.0)
return x_norm
end
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment