Skip to content

Commit

Permalink
initial support for GPU for #55
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Nov 2, 2023
1 parent 65d21b6 commit fe3c827
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 5 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export(brulee_logistic_reg)
export(brulee_mlp)
export(brulee_multinomial_reg)
export(coef)
export(guess_brulee_device)
export(matrix_to_dataset)
export(schedule_cyclic)
export(schedule_decay_expo)
Expand Down
2 changes: 1 addition & 1 deletion R/convert_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#' matrix_to_dataset(as.matrix(mtcars[, -1]), mtcars$mpg)
#' }
#' @export
matrix_to_dataset <- function(x, y, device) {
matrix_to_dataset <- function(x, y, device = "cpu") {
x <- torch::torch_tensor(x, device = device)
if (is.factor(y)) {
y <- as.numeric(y)
Expand Down
3 changes: 2 additions & 1 deletion R/device.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#' Determine an appropriate computational device for torch
#'
#' Uses \pkg{torch} functions to determine if there is a GPU available for use.
#' @return A character string, one of: `"cpu"`, `"cuda"`, or `"mps"`.
#' @examplesI
#' @examples
#' guess_brulee_device()
#' @export
guess_brulee_device <- function() {
Expand Down
9 changes: 7 additions & 2 deletions R/mlp-fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@
#' @param stop_iter A non-negative integer for how many iterations with no
#' improvement before stopping.
#' @param verbose A logical that prints out the iteration history.
#' @param device A character string or `NULL` (if you want it to guess). Possible
#' values are `"cpu"`, `"cuda"`, `"mps"`, `"auto"`. The last value uses
#' [guess_brulee_device()].
#' @param ... Options to pass to the learning rate schedulers via
#' [set_learn_rate()]. For example, the `reduction` or `steps` arguments to
#' [schedule_step()] could be passed here.
Expand Down Expand Up @@ -465,9 +468,11 @@ brulee_mlp_bridge <- function(processed, epochs, hidden_units, activation,
check_logical(verbose, single = TRUE, fn = f_nm)
check_character(activation, single = FALSE, fn = f_nm)


# ------------------------------------------------------------------------------

if (is.null(device)) {
device <- rlang::arg_match(device, c("cpu", "auto", "cuda", "mps"))
if (device == "auto") {
device <- guess_brulee_device()
}

Expand Down Expand Up @@ -661,7 +666,7 @@ mlp_fit_imp <-
dl <- torch::dataloader(ds, batch_size = batch_size)

if (validation > 0) {
ds_val <- brulee::matrix_to_dataset(x_val, y_val)
ds_val <- brulee::matrix_to_dataset(x_val, y_val, device = device)
dl_val <- torch::dataloader(ds_val)
}

Expand Down
7 changes: 7 additions & 0 deletions man/brulee_mlp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions man/guess_brulee_device.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/matrix_to_dataset.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit fe3c827

Please sign in to comment.