diff --git a/experiments/models/default_models.jl b/experiments/models/default_models.jl index 53319cd9e65c7b8e9b8f5c9e21c863aecdc4d5c6..134151593e604a65fb55360f857ea64753122ea3 100644 --- a/experiments/models/default_models.jl +++ b/experiments/models/default_models.jl @@ -49,6 +49,7 @@ function default_models(; batch_size=batch_size, finaliser=finaliser, loss=loss, + acceleration=CUDALibs(), ) # Deep Ensemble: @@ -65,7 +66,8 @@ function default_models(; jem_training_params=( α=α, verbosity=verbosity, ), - sampling_steps=sampling_steps + sampling_steps=sampling_steps, + acceleration=CUDALibs(), ) # Deep Ensemble of Joint Energy Models: