Skip to content
Snippets Groups Projects
Commit b92f9a26 authored by pat-alt's avatar pat-alt
Browse files

error accessing learning rate of composed optimisers

parent be8b63d9
No related branches found
No related tags found
1 merge request!8985 overshooting
...@@ -20,7 +20,7 @@ function default_generators(; ...@@ -20,7 +20,7 @@ function default_generators(;
generator_dict = Dict( generator_dict = Dict(
"Wachter" => WachterGenerator(λ = λ₁, opt = opt), "Wachter" => WachterGenerator(λ = λ₁, opt = opt),
"REVISE" => REVISEGenerator(λ = λ₁, opt = opt), "REVISE" => REVISEGenerator(λ = λ₁, opt = opt),
"Schut" => GreedyGenerator(η = opt.eta), "Schut" => GreedyGenerator(η=get_learning_rate(opt)),
"ECCCo" => ECCCoGenerator( "ECCCo" => ECCCoGenerator(
λ = Λ, λ = Λ,
opt = opt, opt = opt,
......
...@@ -38,7 +38,8 @@ function grid_search( ...@@ -38,7 +38,8 @@ function grid_search(
@info "Running experiment $(counter)/$(length(grid)) with tuning parameters: $(params)" @info "Running experiment $(counter)/$(length(grid)) with tuning parameters: $(params)"
# Filter out keyword parameters that are tuned: # Filter out keyword parameters that are tuned:
not_these = keys(kwargs)[findall([k in map(k -> k[1], params) for k in keys(kwargs)])] _keys = [k[1] for k in kwargs]
not_these = _keys[findall([k in map(k -> k[1], params) for k in _keys])]
not_these = (not_these..., :n_individuals) not_these = (not_these..., :n_individuals)
kwargs = filter(x -> !(x[1] not_these), kwargs) kwargs = filter(x -> !(x[1] not_these), kwargs)
......
...@@ -83,7 +83,7 @@ function meta_generators( ...@@ -83,7 +83,7 @@ function meta_generators(
generator_params = DataFrame( generator_params = DataFrame(
Dict( Dict(
:opt => string(typeof(opt)), :opt => string(typeof(opt)),
:eta => opt.eta, :eta => get_learning_rate(opt),
:dataname => exper.dataname, :dataname => exper.dataname,
:lambda_1 => string(Λ[1]), :lambda_1 => string(Λ[1]),
:lambda_2 => string(Λ[2]), :lambda_2 => string(Λ[2]),
......
using CounterfactualExplanations.Parallelization: ThreadsParallelizer using CounterfactualExplanations.Parallelization: ThreadsParallelizer
using Flux
using LinearAlgebra: norm using LinearAlgebra: norm
function is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer}) function is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer})
...@@ -25,3 +26,15 @@ function standardize(x::AbstractArray) ...@@ -25,3 +26,15 @@ function standardize(x::AbstractArray)
x_norm = replace(x_norm, NaN => 0.0) x_norm = replace(x_norm, NaN => 0.0)
return x_norm return x_norm
end end
function get_learning_rate(opt::Flux.Optimise.AbstractOptimiser)
if hasfield(typeof(opt), :eta)
return opt.eta
elseif hasfield(typeof(opt), :os)
_os = opt.os
opt = _os[findall([:eta in fieldnames(typeof(o)) for o in _os])][1]
return opt.eta
else
throw(ArgumentError("Cannot find learning rate."))
end
end
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment