diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl index f07bd6dbbd541814d7ef9225f82cee8988ab7509..c3bacbce7018e2e3e3721d2fdb28a5302d355482 100644 --- a/experiments/benchmarking/benchmarking.jl +++ b/experiments/benchmarking/benchmarking.jl @@ -20,7 +20,7 @@ function default_generators(; generator_dict = Dict( "Wachter" => WachterGenerator(λ = λâ‚, opt = opt), "REVISE" => REVISEGenerator(λ = λâ‚, opt = opt), - "Schut" => GreedyGenerator(η = opt.eta), + "Schut" => GreedyGenerator(η=get_learning_rate(opt)), "ECCCo" => ECCCoGenerator( λ = Λ, opt = opt, diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index d5d2455c0b9d1535cac967416258189468df409e..36904a0815b8dafdc8c17a6ef275966f64eff75e 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -38,7 +38,8 @@ function grid_search( @info "Running experiment $(counter)/$(length(grid)) with tuning parameters: $(params)" # Filter out keyword parameters that are tuned: - not_these = keys(kwargs)[findall([k in map(k -> k[1], params) for k in keys(kwargs)])] + _keys = [k[1] for k in kwargs] + not_these = _keys[findall([k in map(k -> k[1], params) for k in _keys])] not_these = (not_these..., :n_individuals) kwargs = filter(x -> !(x[1] ∈ not_these), kwargs) diff --git a/experiments/post_processing/meta_data.jl b/experiments/post_processing/meta_data.jl index ab4c982a14ff886e16a5cf7ad59edd86db945700..0acd94c3f835b21fc8d07acf2e6b687e0e923c01 100644 --- a/experiments/post_processing/meta_data.jl +++ b/experiments/post_processing/meta_data.jl @@ -83,7 +83,7 @@ function meta_generators( generator_params = DataFrame( Dict( :opt => string(typeof(opt)), - :eta => opt.eta, + :eta => get_learning_rate(opt), :dataname => exper.dataname, :lambda_1 => string(Λ[1]), :lambda_2 => string(Λ[2]), diff --git a/experiments/utils.jl b/experiments/utils.jl index 2f9b68450b109e750bfee8ba2f17cdc55b28f6bc..be94743d216d6b1ccb07e5094f4a5913ceddeff1 100644 --- a/experiments/utils.jl +++ b/experiments/utils.jl @@ -1,4 +1,5 @@ using CounterfactualExplanations.Parallelization: ThreadsParallelizer +using Flux using LinearAlgebra: norm function is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer}) @@ -25,3 +26,15 @@ function standardize(x::AbstractArray) x_norm = replace(x_norm, NaN => 0.0) return x_norm end + +function get_learning_rate(opt::Flux.Optimise.AbstractOptimiser) + if hasfield(typeof(opt), :eta) + return opt.eta + elseif hasfield(typeof(opt), :os) + _os = opt.os + opt = _os[findall([:eta in fieldnames(typeof(o)) for o in _os])][1] + return opt.eta + else + throw(ArgumentError("Cannot find learning rate.")) + end +end \ No newline at end of file