Rebased.
Narsil committed Jun 25, 2024
1 parent 5a9fee6 commit bac7903
Showing 8 changed files with 47 additions and 26 deletions.
4 changes: 3 additions & 1 deletion benchmark/src/main.rs
@@ -147,7 +147,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing::info!("Downloading tokenizer");

// Parse Huggingface hub token
let auth_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok();
let auth_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();

// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime
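The hunks in benchmark/src/main.rs and router/src/main.rs (further down) only reflow the token lookup; the behaviour is unchanged: prefer HF_TOKEN and fall back to the legacy HUGGING_FACE_HUB_TOKEN only when the former is unset. A minimal Python sketch of that fallback order, purely illustrative (the project code is the Rust shown above):

    import os

    # Prefer HF_TOKEN; use the legacy variable only when HF_TOKEN is not set at all.
    auth_token = os.environ.get("HF_TOKEN")
    if auth_token is None:
        auth_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")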
3 changes: 1 addition & 2 deletions launcher/src/main.rs
@@ -762,7 +762,7 @@ fn num_cuda_devices() -> Option<usize> {
Err(_) => match env::var("NVIDIA_VISIBLE_DEVICES") {
Ok(devices) => devices,
Err(_) => env::var("ZE_AFFINITY_MASK").ok()?,
}
},
};
let n_devices = devices.split(',').count();
Some(n_devices)
@@ -1225,7 +1225,6 @@ fn spawn_webserver(
router_args.push("--otlp-service-name".to_string());
router_args.push(otlp_service_name);


// CORS origins
for origin in args.cors_allow_origin.into_iter() {
router_args.push("--cors-allow-origin".to_string());
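The first launcher hunk only adds the trailing comma the rebase required, but it also shows the device-count fallback chain used by num_cuda_devices. A short Python sketch of the same logic; the helper name is hypothetical, and the first variable in the chain (presumably CUDA_VISIBLE_DEVICES) sits in the elided part of the hunk:

    import os

    def num_visible_devices() -> int | None:
        """Take the first of the three environment variables that is set and
        count its comma-separated entries; None when none of them is set."""
        for var in ("CUDA_VISIBLE_DEVICES", "NVIDIA_VISIBLE_DEVICES", "ZE_AFFINITY_MASK"):
            devices = os.environ.get(var)
            if devices is not None:
                return len(devices.split(","))
        return None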
12 changes: 1 addition & 11 deletions router/src/infer/v2/scheduler.rs
@@ -39,25 +39,15 @@ impl SchedulerV2 {
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
<<<<<<< HEAD:router/src/infer/v2/scheduler.rs
let queue = Queue::new(requires_padding, 16, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new());
=======
// Infer shared state
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let block_size = std::env::var("BLOCK_SIZE")
.map(|b| b.parse().unwrap_or(block_size))
.unwrap_or(block_size);
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let shared = Arc::new(Shared {
batching_task: Notify::new(),
});
>>>>>>> Using flash decoding:router/src/infer.rs
let batching_task_notifier = Arc::new(Notify::new());

// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
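The rebase removes the leftover merge-conflict markers from the v2 scheduler, and with them the BLOCK_SIZE environment override that the conflicted block carried; v2 keeps its fixed block size of 16. For reference, a Python sketch of what the dropped override did (use BLOCK_SIZE from the environment when it parses, otherwise keep the default):

    import os

    def block_size_with_override(default: int = 16) -> int:
        """Equivalent of the removed `.map(|b| b.parse().unwrap_or(block_size))`
        chain: a missing or malformed BLOCK_SIZE falls back to the default."""
        raw = os.environ.get("BLOCK_SIZE")
        if raw is None:
            return default
        try:
            return int(raw)
        except ValueError:
            return default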
8 changes: 7 additions & 1 deletion router/src/infer/v3/scheduler.rs
@@ -39,9 +39,15 @@ impl SchedulerV3 {
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(
requires_padding,
16,
block_size,
window_size,
speculate,
max_batch_total_tokens,
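The v3 scheduler now derives the queue's block size from the FLASH_DECODING flag instead of hard-coding 16. A Python sketch of the same decision; note that the Rust check lowercases the value and accepts "1" or "true", while globals.py further down compares against the literal set {"1", "true", "True"}:

    import os

    def flash_decoding_block_size() -> int:
        """256-token KV-cache blocks when FLASH_DECODING is enabled, 16 otherwise;
        mirrors the Rust check (lowercase, then accept "1" or "true")."""
        value = os.environ.get("FLASH_DECODING", "")
        return 256 if value.lower() in {"1", "true"} else 16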
4 changes: 3 additions & 1 deletion router/src/main.rs
@@ -159,7 +159,9 @@ async fn main() -> Result<(), RouterError> {
});

// Parse Huggingface hub token
let authorization_token = std::env::var("HF_TOKEN").or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN")).ok();
let authorization_token = std::env::var("HF_TOKEN")
.or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
.ok();

// Tokenizer instance
// This will only be used to validate payloads
3 changes: 1 addition & 2 deletions server/text_generation_server/models/cache_manager.py
@@ -3,9 +3,8 @@

from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE

BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None

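cache_manager.py stops computing its own BLOCK_SIZE and instead imports the value resolved once in globals.py. A tiny illustrative consumer; the head count, head size and fp16 element size below are hypothetical, chosen only to make the arithmetic concrete:

    from text_generation_server.models.globals import BLOCK_SIZE, FLASH_DECODING

    # Rough per-block KV memory for one layer with 32 heads of size 128 in fp16:
    # 2 tensors (K and V) * BLOCK_SIZE tokens * 32 heads * 128 dims * 2 bytes each.
    bytes_per_block = 2 * BLOCK_SIZE * 32 * 128 * 2
    print(
        f"FLASH_DECODING={FLASH_DECODING}, BLOCK_SIZE={BLOCK_SIZE}, "
        f"~{bytes_per_block / 2**20:.2f} MiB per block per layer"
    )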
32 changes: 27 additions & 5 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -27,8 +27,12 @@
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
import text_generation_server.models.globals as tgi_globals
from text_generation_server.models.globals import (
MODEL_ID,
FLASH_DECODING,
MEM_POOL,
CUDA_GRAPHS,
)
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION

@@ -40,7 +44,9 @@

tracer = trace.get_tracer(__name__)

BLOCK_SIZE: int = 256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
BLOCK_SIZE: int = (
256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
)

# Will be set in init
SLIDING_WINDOW: Optional[int] = None
@@ -770,7 +776,23 @@ def init_kv_cache(
element_size = torch.tensor([], dtype=dtype).element_size()
x = BLOCK_SIZE // element_size

if SYSTEM == "ipex" and device == torch.device("cpu"):
if FLASH_DECODING:
self.kv_cache = [
(
torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
elif SYSTEM == "ipex" and device == torch.device("cpu"):
self.kv_cache = [
(
torch.empty(
@@ -934,7 +956,7 @@ def warmup(self, batch: FlashCausalLMBatch):

tunableop_filepath = os.path.join(
HUGGINGFACE_HUB_CACHE,
f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
)

logger.info(
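The substantive change in flash_causal_lm.py is the cache layout: with FLASH_DECODING each layer's key and value caches become plain (num_blocks, BLOCK_SIZE, num_heads, head_size) tensors rather than the paged layout used on the other branches of init_kv_cache. A self-contained Python sketch of the flash-decoding allocation shown above; the parameter values in the usage line are placeholders:

    from typing import List, Tuple

    import torch

    def allocate_flash_decoding_kv_cache(
        num_blocks: int,
        block_size: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """One (key, value) pair per layer, each of shape
        (num_blocks, block_size, num_heads, head_size), matching the
        FLASH_DECODING branch of init_kv_cache above."""
        shape = (num_blocks, block_size, num_heads, head_size)
        return [
            (
                torch.empty(shape, dtype=dtype, device=device),
                torch.empty(shape, dtype=dtype, device=device),
            )
            for _ in range(num_layers)
        ]

    # Example: 16 blocks of 256 tokens for a 2-layer model with 8 KV heads of size 128.
    kv_cache = allocate_flash_decoding_kv_cache(
        16, 256, 2, 8, 128, dtype=torch.float16, device=torch.device("cpu")
    )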
7 changes: 4 additions & 3 deletions server/text_generation_server/models/globals.py
@@ -4,10 +4,13 @@

MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
cuda_graphs = os.getenv("CUDA_GRAPHS")
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
if FLASH_DECODING:
logger.info("Using FLASH_DECODING")


cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
try:
cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
@@ -17,8 +20,6 @@
)
else:
cuda_graphs = None


# sorting the cuda graphs in descending order helps reduce the
# memory impact and results in less memory usage
if cuda_graphs is not None:
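globals.py now owns both the flash-decoding flag and the block size, next to the existing CUDA_GRAPHS parsing. A standalone sketch of the resolution order at import time; the error handling for a malformed CUDA_GRAPHS value lives in the collapsed part of the hunk and is not reproduced here:

    import os

    # Flash decoding is opt-in via the environment and also changes the block size.
    FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
    BLOCK_SIZE = 256 if FLASH_DECODING else 16

    # CUDA_GRAPHS is an optional comma-separated list of batch sizes, e.g. "1,2,4,8".
    cuda_graphs = os.getenv("CUDA_GRAPHS")
    if cuda_graphs is not None:
        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
        # Descending order reduces the memory impact, as noted in the hunk above.
        cuda_graphs = sorted(cuda_graphs, reverse=True)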
