diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl index ad0f9daa9937239c9096fe61c438df7810938771..c46412002150bd52c533bbabd89008a0c43f0a9f 100644 --- a/experiments/gmsc.jl +++ b/experiments/gmsc.jl @@ -1,8 +1,9 @@ counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=TEST_SIZE) +nobs = size(counterfactual_data.X, 2) # Default builder: -n_hidden = 128 -activation = Flux.swish +n_hidden = 32 +activation = Flux.relu builder = MLJFlux.@builder Flux.Chain( Dense(n_in, n_hidden, activation), Dense(n_hidden, n_hidden, activation), @@ -24,4 +25,5 @@ run_experiment( opt = Flux.Optimise.Descent(0.05), n_individuals = n_ind, use_variants = false, + min_batch_size = 250, ) \ No newline at end of file