From b92f9a26bbcecb565487540814cbcea43cecde97 Mon Sep 17 00:00:00 2001
From: pat-alt <altmeyerpat@gmail.com>
Date: Wed, 18 Oct 2023 12:32:32 +0200
Subject: [PATCH] fix error when accessing learning rate of composed optimisers

---
 experiments/benchmarking/benchmarking.jl |  2 +-
 experiments/grid_search.jl               |  3 ++-
 experiments/post_processing/meta_data.jl |  2 +-
 experiments/utils.jl                     | 13 +++++++++++++
 4 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl
index f07bd6db..c3bacbce 100644
--- a/experiments/benchmarking/benchmarking.jl
+++ b/experiments/benchmarking/benchmarking.jl
@@ -20,7 +20,7 @@ function default_generators(;
         generator_dict = Dict(
             "Wachter" => WachterGenerator(λ = λ₁, opt = opt),
             "REVISE" => REVISEGenerator(λ = λ₁, opt = opt),
-            "Schut" => GreedyGenerator(η = opt.eta),
+            "Schut" => GreedyGenerator(η = get_learning_rate(opt)),
             "ECCCo" => ECCCoGenerator(
                 λ = Λ,
                 opt = opt,
diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl
index d5d2455c..36904a08 100644
--- a/experiments/grid_search.jl
+++ b/experiments/grid_search.jl
@@ -38,7 +38,8 @@ function grid_search(
         @info "Running experiment $(counter)/$(length(grid)) with tuning parameters: $(params)"
 
         # Filter out keyword parameters that are tuned:
-        not_these = keys(kwargs)[findall([k in map(k -> k[1], params) for k in keys(kwargs)])]
+        _keys = [k[1] for k in kwargs]
+        not_these = _keys[findall([k in map(k -> k[1], params) for k in _keys])]
         not_these = (not_these..., :n_individuals)
         kwargs = filter(x -> !(x[1] ∈ not_these), kwargs)
 
diff --git a/experiments/post_processing/meta_data.jl b/experiments/post_processing/meta_data.jl
index ab4c982a..0acd94c3 100644
--- a/experiments/post_processing/meta_data.jl
+++ b/experiments/post_processing/meta_data.jl
@@ -83,7 +83,7 @@ function meta_generators(
     generator_params = DataFrame(
         Dict(
             :opt => string(typeof(opt)),
-            :eta => opt.eta,
+            :eta => get_learning_rate(opt),
             :dataname => exper.dataname,
             :lambda_1 => string(Λ[1]),
             :lambda_2 => string(Λ[2]),
diff --git a/experiments/utils.jl b/experiments/utils.jl
index 2f9b6845..be94743d 100644
--- a/experiments/utils.jl
+++ b/experiments/utils.jl
@@ -1,4 +1,5 @@
 using CounterfactualExplanations.Parallelization: ThreadsParallelizer
+using Flux
 using LinearAlgebra: norm
 
 function is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer})
@@ -25,3 +26,15 @@ function standardize(x::AbstractArray)
     x_norm = replace(x_norm, NaN => 0.0)
     return x_norm
 end
+
+function get_learning_rate(opt::Flux.Optimise.AbstractOptimiser)
+    if hasfield(typeof(opt), :eta)
+        return opt.eta
+    elseif hasfield(typeof(opt), :os)
+        # Composed optimiser (e.g. Flux.Optimiser): use first component exposing `eta`.
+        i = findfirst(o -> hasfield(typeof(o), :eta), opt.os)
+        isnothing(i) && throw(ArgumentError("Cannot find learning rate."))
+        return opt.os[i].eta
+    else
+        throw(ArgumentError("Cannot find learning rate."))
+end
\ No newline at end of file
-- 
GitLab