diff --git a/experiments/california_housing.jl b/experiments/california_housing.jl index 99e598a9596a2f9ff84f3bed7e22c1f7a3676f4f..e8a8db16f50a98d467f8cbb9d8b21d9a5b9b3d8d 100644 --- a/experiments/california_housing.jl +++ b/experiments/california_housing.jl @@ -44,7 +44,7 @@ params = ( ) # Best grid search params: -append_best_params!(params, dataname) +params = append_best_params(params, dataname) if GRID_SEARCH grid_search( diff --git a/experiments/circles.jl b/experiments/circles.jl index f07899065ad3757e2fb457b84fd49c3ebe39682e..0d4a6374afeab3df6432dd03d1a4e49339d2e6df 100644 --- a/experiments/circles.jl +++ b/experiments/circles.jl @@ -29,7 +29,7 @@ params = ( ) # Best grid search params: -append_best_params!(params, dataname) +params = append_best_params(params, dataname) if GRID_SEARCH grid_search( diff --git a/experiments/german_credit.jl b/experiments/german_credit.jl index 082567b2d6fa123feba3627a378ff44bc31c7fce..6a127b07dc31ed445443c5f9ce170669fd18f4b2 100644 --- a/experiments/german_credit.jl +++ b/experiments/german_credit.jl @@ -44,7 +44,7 @@ params = ( ) # Best grid search params: -append_best_params!(params, dataname) +params = append_best_params(params, dataname) if GRID_SEARCH grid_search( diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl index 3c55d0a585569b8a0712b8afb46489b38b92d721..2b93ff4e5aa4d846f9ae76411e7a470a16d84c4b 100644 --- a/experiments/gmsc.jl +++ b/experiments/gmsc.jl @@ -44,7 +44,7 @@ params = ( ) # Best grid search params: -append_best_params!(params, dataname) +params = append_best_params(params, dataname) if GRID_SEARCH grid_search( diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index 0c11b9f293b365bb4e833ab7fb317090ea2bf1fb..fff7c41427728660de991179345a825abf239cff 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -56,9 +56,9 @@ function grid_search( ) # Collect: - params = map(x -> typeof(x[2]) <: Vector ? x[1] => Tuple(x[2]) : x[1] => x[2], params) + _params = map(x -> typeof(x[2]) <: Vector ? x[1] => Tuple(x[2]) : x[1] => x[2], params) df_params = - DataFrame(merge(Dict(:id => counter), Dict(params))) |> + DataFrame(merge(Dict(:id => counter), Dict(_params))) |> x -> select(x, :id, Not(:id)) df_outcomes = DataFrame(Dict(:id => counter, :params => params, :outcome => outcome)) |> @@ -236,7 +236,7 @@ best_outcome(outcomes; measure=["distance_from_energy_l2"]) = best_absolute_outc Appends the best parameters from grid search results to the specified parameters. """ -function append_best_params!(params::NamedTuple, dataname::String) +function append_best_params(params::NamedTuple, dataname::String) if !isfile( joinpath( DEFAULT_OUTPUT_PATH, @@ -256,6 +256,8 @@ function append_best_params!(params::NamedTuple, dataname::String) ) best_params = best_outcome(grid_search_results).params params = (; params..., best_params...) - @info "Best parameters: $(best_params)" + + params = (; params..., (; Λ = typeof(params.Λ) <: Tuple ? collect(params.Λ) : params.Λ)...) end + return params end diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl index 2421bd9a9beef61da801b325ebec1a45ab986a2a..43554f3647d8dc92dc5ae03201044faadf84bca9 100644 --- a/experiments/linearly_separable.jl +++ b/experiments/linearly_separable.jl @@ -29,7 +29,8 @@ params = ( ) # Best grid search params: -append_best_params!(params, dataname) +params = append_best_params(params, dataname) +@info "Using the following parameters: $(params)" if GRID_SEARCH grid_search( diff --git a/experiments/moons.jl b/experiments/moons.jl index a15002795c7a4d135eba54c7c6e8088098fa193d..809bdc406ace3d45f84f85d118561803ccbeebca 100644 --- a/experiments/moons.jl +++ b/experiments/moons.jl @@ -28,7 +28,7 @@ params = ( ) # Best grid search params: -append_best_params!(params, dataname) +params = append_best_params(params, dataname) if GRID_SEARCH grid_search(