From b92f9a26bbcecb565487540814cbcea43cecde97 Mon Sep 17 00:00:00 2001 From: pat-alt <altmeyerpat@gmail.com> Date: Wed, 18 Oct 2023 12:32:32 +0200 Subject: [PATCH] error accessing learning rate of composed optimisers --- experiments/benchmarking/benchmarking.jl | 2 +- experiments/grid_search.jl | 3 ++- experiments/post_processing/meta_data.jl | 2 +- experiments/utils.jl | 13 +++++++++++++ 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl index f07bd6db..c3bacbce 100644 --- a/experiments/benchmarking/benchmarking.jl +++ b/experiments/benchmarking/benchmarking.jl @@ -20,7 +20,7 @@ function default_generators(; generator_dict = Dict( "Wachter" => WachterGenerator(λ = λâ‚, opt = opt), "REVISE" => REVISEGenerator(λ = λâ‚, opt = opt), - "Schut" => GreedyGenerator(η = opt.eta), + "Schut" => GreedyGenerator(η=get_learning_rate(opt)), "ECCCo" => ECCCoGenerator( λ = Λ, opt = opt, diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index d5d2455c..36904a08 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -38,7 +38,8 @@ function grid_search( @info "Running experiment $(counter)/$(length(grid)) with tuning parameters: $(params)" # Filter out keyword parameters that are tuned: - not_these = keys(kwargs)[findall([k in map(k -> k[1], params) for k in keys(kwargs)])] + _keys = [k[1] for k in kwargs] + not_these = _keys[findall([k in map(k -> k[1], params) for k in _keys])] not_these = (not_these..., :n_individuals) kwargs = filter(x -> !(x[1] ∈ not_these), kwargs) diff --git a/experiments/post_processing/meta_data.jl b/experiments/post_processing/meta_data.jl index ab4c982a..0acd94c3 100644 --- a/experiments/post_processing/meta_data.jl +++ b/experiments/post_processing/meta_data.jl @@ -83,7 +83,7 @@ function meta_generators( generator_params = DataFrame( Dict( :opt => string(typeof(opt)), - :eta => opt.eta, + :eta => get_learning_rate(opt), :dataname => exper.dataname, :lambda_1 => string(Λ[1]), :lambda_2 => string(Λ[2]), diff --git a/experiments/utils.jl b/experiments/utils.jl index 2f9b6845..be94743d 100644 --- a/experiments/utils.jl +++ b/experiments/utils.jl @@ -1,4 +1,5 @@ using CounterfactualExplanations.Parallelization: ThreadsParallelizer +using Flux using LinearAlgebra: norm function is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer}) @@ -25,3 +26,15 @@ function standardize(x::AbstractArray) x_norm = replace(x_norm, NaN => 0.0) return x_norm end + +function get_learning_rate(opt::Flux.Optimise.AbstractOptimiser) + if hasfield(typeof(opt), :eta) + return opt.eta + elseif hasfield(typeof(opt), :os) + _os = opt.os + opt = _os[findall([:eta in fieldnames(typeof(o)) for o in _os])][1] + return opt.eta + else + throw(ArgumentError("Cannot find learning rate.")) + end +end \ No newline at end of file -- GitLab