From 02ac45131f3bfe8b2155d3cd00e3a8591eb65655 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 27 Jun 2024 13:33:35 +0000 Subject: [PATCH] some cleaning --- .../custom_modeling/flash_llama_modeling.py | 31 +++---------------- .../models/flash_causal_lm.py | 17 ---------- 2 files changed, 4 insertions(+), 44 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 3f08c810254..8cb8c0a9713 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -111,7 +111,6 @@ def __init__( prefix: str, config, weights, - layer_idx, ): super().__init__() self.num_heads = config.num_attention_heads @@ -144,7 +143,6 @@ def __init__( self.query_key_value = load_attention(config, prefix, weights, index) self.index = index - self.layer_idx = layer_idx o_proj = TensorParallelRowLinear.load( config, @@ -165,8 +163,6 @@ def __init__( 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device ).repeat_interleave(self.num_groups) - self.step = 0 - def forward( self, hidden_states, @@ -198,18 +194,6 @@ def forward( # output tensor attn_output = torch.empty_like(query) - if self.layer_idx < 4: - torch.save(query, f"query_states_step{self.step}_layer{self.layer_idx}.pt") - if cu_seqlen_prefill is not None: - torch.save( - torch.select(kv, dim=1, index=0), - f"key_states_step{self.step}_layer{self.layer_idx}.pt", - ) - torch.save( - torch.select(kv, dim=1, index=1), - f"value_states_step{self.step}_layer{self.layer_idx}.pt", - ) - # Prefill if cu_seqlen_prefill is not None: # flash attention @@ -236,14 +220,9 @@ def forward( max_s, ) - attn_output = attn_output.view(-1, self.num_heads * self.head_size) - if self.layer_idx < 4: - torch.save( - attn_output, f"attn_output_step{self.step}_layer{self.layer_idx}.pt" - ) - - self.step += 1 - return self.o_proj(attn_output, adapter_data) + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), adapter_data + ) class LlamaMLP(nn.Module): @@ -342,14 +321,13 @@ def forward(self, hidden_states, adapter_data): class FlashLlamaLayer(nn.Module): - def __init__(self, index, prefix, config, weights, layer_idx): + def __init__(self, index, prefix, config, weights): super().__init__() self.self_attn = FlashLlamaAttention( index=index, prefix=f"{prefix}.self_attn", config=config, weights=weights, - layer_idx=layer_idx, ) self.mlp = LlamaMLP( prefix=f"{prefix}.mlp", config=config, weights=weights, index=index @@ -422,7 +400,6 @@ def __init__(self, prefix, config, weights): ), config=config, weights=weights, - layer_idx=layer_id, ) for layer_id in range(config.num_hidden_layers) ] diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a19057944d9..f7678762592 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1149,23 +1149,6 @@ def forward( cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - logger.info(f"input_ids {input_ids} {input_ids.shape}") - logger.info(f"position_ids {position_ids} {position_ids.shape}") - logger.info( - f"cu_seqlen_prefill {cu_seqlen_prefill} {cu_seqlen_prefill.shape if cu_seqlen_prefill is not None else 'NONE'}" - ) - logger.info( - f"kv_cache {type(kv_cache)}, len={len(kv_cache)}, {len(kv_cache[0])}, shape={kv_cache[0][0].shape}" - ) - logger.info( - f"block_tables {type(block_tables)} {block_tables.shape} {block_tables}" - ) - logger.info(f"slots {type(slots)} {slots.shape} {slots}") - logger.info(f"input_lengths {input_lengths}") - logger.info(f"max_s {max_s}") - logger.info(f"prefill_cache_indices {batch.prefill_cache_indices}") - logger.info(f"lm_head_indices {lm_head_indices}") - logger.info(f"adapter_data {adapter_data}") logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids,