From f6d7a7bd05b4f42a096d9c1e20c0c986a73e7938 Mon Sep 17 00:00:00 2001 From: Alexander Christensen Date: Sun, 28 Jan 2024 16:14:28 -0600 Subject: [PATCH] some number updates for `auto_device` --- R/auto_device.R | 94 ++++++++++++++++++++++++------------------------- R/rag.R | 6 ++-- 2 files changed, 50 insertions(+), 50 deletions(-) diff --git a/R/auto_device.R b/R/auto_device.R index 503ac49..27853bd 100644 --- a/R/auto_device.R +++ b/R/auto_device.R @@ -3,86 +3,86 @@ # Updated 28.01.2024 auto_device <- function(device, transformer) { - + # Set transformer memory (MB) # Numbers derived from overall memory usage on Alex's 1x A6000 - # Single run of `rag` with each model using 3800 tweets + # Single run of `rag` with each model using 3800 tweets transformer_memory <- round( switch( transformer, - "tinyllama" = 5504, "llama-2" = 5964, + "tinyllama" = 5324, "llama-2" = 5798, "mistral-7b" = 30018, "orca-2" = 29836, "phi-2" = 13594 ), digits = -2 ) - + # First, check for "auto" if(device == "auto"){ - + # Import {torch} torch <- reticulate::import("torch") - + # Check for CUDA if(torch$cuda$is_available()){ - + # Number of GPU devices n_gpu <- torch$cuda$device_count() - + # Branch for number of GPUs if(n_gpu == 1){ - + # Check for available memory (returns MB) gpu_memory <- get_gpu_memory(torch, device_no = 0) - + # Switch to CPU if not enough GPU device <- ifelse(gpu_memory > transformer_memory, "cuda:0", "cpu") - + }else{ - + # Initialize GPU memory gpu_memory <- numeric(length = n_gpu) - + # Loop over and get GPUs for(i in 1:n_gpu){ gpu_memory[i] <- get_gpu_memory(torch, device_no = i - 1) } - + # Check how many GPUs are needed device <- ifelse(gpu_memory[1] > transformer_memory, "cuda:0", "auto") - + } - + }else{device <- "cpu"} - + } - + # Second, check for "cpu" if(device == "cpu"){ - + # Get CPU memory cpu_memory <- get_cpu_memory() - + # Check for not enough memory if(cpu_memory < transformer_memory){ - + # Send error stop( paste0( - "Cannot load the LLM (", transformer_memory, "MB). ", + "Cannot load the LLM (", transformer_memory, "MB). ", "Not enough memory (", cpu_memory, "MB)" ), call. = FALSE ) - + } - + } - + # Send device to user message(paste0("Using device: ", device)) - + # Return device return(device) - + } #' @noRd @@ -90,14 +90,14 @@ auto_device <- function(device, transformer) # Updated 28.01.2024 get_gpu_memory <- function(torch, device_no) { - + # Check for available memory gpu_memory <- capture.output(torch$cuda$get_device_properties(device_no)) - + # Extract memory gpu_memory <- gsub(".*total_memory=", "", gpu_memory) gpu_memory <- gsub("MB,.*", "", gpu_memory) - + # Return value (MB) return(as.numeric(gpu_memory)) @@ -108,47 +108,47 @@ get_gpu_memory <- function(torch, device_no) # Updated 28.01.2024 get_cpu_memory <- function() { - + # Get operating system OS <- tolower(Sys.info()["sysname"]) - + # Branch based on OS if(OS == "windows"){ # Windows - + # Alternative (outputs memory in kB) bytes <- as.numeric( trimws(system("wmic OS get FreePhysicalMemory", intern = TRUE))[2] ) * 1e+03 - + }else if(OS == "linux"){ # Linux - + # Split system information info_split <- strsplit(system("free", intern = TRUE), split = " ") - + # Remove "Mem:" and "Swap:" info_split <- lapply(info_split, function(x){gsub("Mem:", "", x)}) info_split <- lapply(info_split, function(x){gsub("Swap:", "", x)}) - + # Get actual values info_split <- lapply(info_split, function(x){x[x != ""]}) - + # Bind values info_split <- do.call(rbind, info_split[1:2]) - + # Get free values (Linux reports in *kilo*bytes -- thanks, Aleksandar Tomasevic) bytes <- as.numeric(info_split[2, info_split[1,] == "available"]) * 1e+03 - + }else{ # Mac - + # System information system_info <- system("top -l 1 -s 0 | grep PhysMem", intern = TRUE) - + # Get everything after comma unused <- gsub(" .*,", "", system_info) - + # Get values only value <- gsub(" unused.", "", gsub("PhysMem: ", "", unused)) - + # Check for bytes if(grepl("M", value)){ bytes <- as.numeric(gsub("M", "", value)) * 1e+06 @@ -161,10 +161,10 @@ get_cpu_memory <- function() }else if(grepl("T", value)){ # edge case bytes <- as.numeric(gsub("T", "", value)) * 1e+12 } - + } - + # Return value (MB) return(bytes / 1e+06) - + } diff --git a/R/rag.R b/R/rag.R index e4d03f5..873e3c0 100644 --- a/R/rag.R +++ b/R/rag.R @@ -150,9 +150,6 @@ rag <- function( if(missing(response_mode)){ device <- "auto" }else{device <- match.arg(device)} - - # Set device - device <- auto_device(device, transformer) # Run setup for modules setup_modules() @@ -179,6 +176,9 @@ rag <- function( # Get service context if(!exists("service_context", envir = as.environment(envir))){ + # Set device + device <- auto_device(device, transformer) + # Set up service context service_context <- switch( transformer,