[refactor] move cache_manager out of DiTRuntimeState (xdit-project#237)
feifeibear committed Oct 25, 2024
1 parent cf4aad4 commit 1885659
Showing 6 changed files with 23 additions and 14 deletions.
9 changes: 8 additions & 1 deletion xfuser/core/distributed/runtime_state.py

@@ -111,7 +111,7 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
             * pipeline.transformer.config.attention_head_dim,
         )
         self.pipeline_comm_extra_tensors_info = []
-        self.cache_manager = CacheManager()
+        # self.cache_manager = CacheManager()

     def set_input_parameters(
         self,
@@ -366,6 +366,7 @@ def _reset_recv_skip_buffer(self, num_blocks_per_stage):
 # _RUNTIME: Optional[RuntimeState] = None
 # TODO: change to RuntimeState after implementing the unet
 _RUNTIME: Optional[DiTRuntimeState] = None
+_CACHE_MGR = CacheManager()


 def runtime_state_is_initialized():
@@ -385,3 +386,9 @@ def initialize_runtime_state(pipeline: DiffusionPipeline, engine_config: EngineC
         )
     if hasattr(pipeline, "transformer"):
         _RUNTIME = DiTRuntimeState(pipeline=pipeline, config=engine_config)
+
+
+def get_cache_manager():
+    global _CACHE_MGR
+    assert _CACHE_MGR is not None, "Cache manager has not been initialized."
+    return _CACHE_MGR
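
The hunks above capture the whole refactor: the KV cache manager stops being an attribute of DiTRuntimeState and becomes a module-level singleton reached through get_cache_manager(). Below is a minimal, runnable sketch of that pattern; the CacheManager here is a toy stand-in (the real class lives elsewhere in xfuser), and only the singleton-plus-accessor shape mirrors the diff.

from typing import Optional


class CacheManager:
    """Toy stand-in so the sketch runs; not xfuser's real CacheManager."""

    def __init__(self) -> None:
        self.entries = {}


# Module-level singleton, created at import time, as in the hunk above.
_CACHE_MGR: Optional[CacheManager] = CacheManager()


def get_cache_manager() -> CacheManager:
    # Call sites switch from get_runtime_state().cache_manager to this accessor,
    # so touching the KV cache no longer requires a constructed DiTRuntimeState.
    assert _CACHE_MGR is not None, "Cache manager has not been initialized."
    return _CACHE_MGR


if __name__ == "__main__":
    assert get_cache_manager() is get_cache_manager()  # every caller sees one instance

The remaining five files make the matching call-site change, replacing get_runtime_state().cache_manager with get_cache_manager().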
4 changes: 2 additions & 2 deletions xfuser/core/long_ctx_attention/ring/ring_flash_attn.py

@@ -3,7 +3,7 @@
 from yunchang.ring.utils import RingComm, update_out_and_lse
 from yunchang.ring.ring_flash_attn import RingFlashAttnFunc

-from xfuser.core.distributed.runtime_state import get_runtime_state
+from xfuser.core.distributed.runtime_state import get_cache_manager, get_runtime_state


 def ring_flash_attn_forward(
@@ -42,7 +42,7 @@ def ring_flash_attn_forward(
     next_k, next_v = None, None

     if attn_layer is not None:
-        k, v = get_runtime_state().cache_manager.update_and_get_kv_cache(
+        k, v = get_cache_manager().update_and_get_kv_cache(
             new_kv=[k, v],
             layer=attn_layer,
             slice_dim=1,
4 changes: 2 additions & 2 deletions xfuser/core/long_ctx_attention/ulysses/attn_layer.py

@@ -9,7 +9,7 @@
 from yunchang.comm.all_to_all import SeqAllToAll4D
 from yunchang.ulysses.attn_layer import torch_attn

-from xfuser.core.distributed.runtime_state import get_runtime_state
+from xfuser.core.distributed.runtime_state import get_cache_manager, get_runtime_state


 class xFuserUlyssesAttention(UlyssesAttention):
@@ -120,7 +120,7 @@ def forward(
         )

         if self.use_kv_cache:
-            k, v = get_runtime_state().cache_manager.update_and_get_kv_cache(
+            k, v = get_cache_manager().update_and_get_kv_cache(
                 new_kv=[k, v],
                 layer=attn,
                 slice_dim=1,
12 changes: 6 additions & 6 deletions xfuser/model_executor/layers/attention_processor.py

@@ -21,7 +21,7 @@
     get_sequence_parallel_rank,
     get_sp_group,
 )
-from xfuser.core.distributed.runtime_state import get_runtime_state
+from xfuser.core.distributed.runtime_state import get_cache_manager, get_runtime_state
 from xfuser.model_executor.layers import xFuserLayerBaseWrapper
 from xfuser.model_executor.layers import xFuserLayerWrappersRegister
 from xfuser.logger import init_logger
@@ -266,7 +266,7 @@ def __call__(

         #! ---------------------------------------- KV CACHE ----------------------------------------
         if not self.use_long_ctx_attn_kvcache:
-            key, value = get_runtime_state().cache_manager.update_and_get_kv_cache(
+            key, value = get_cache_manager().update_and_get_kv_cache(
                 new_kv=[key, value],
                 layer=attn,
                 slice_dim=2,
@@ -417,7 +417,7 @@ def __call__(

         #! ---------------------------------------- KV CACHE ----------------------------------------
         if not self.use_long_ctx_attn_kvcache:
-            key, value = get_runtime_state().cache_manager.update_and_get_kv_cache(
+            key, value = get_cache_manager().update_and_get_kv_cache(
                 new_kv=[key, value],
                 layer=attn,
                 slice_dim=1,
@@ -645,7 +645,7 @@ def __call__(
             encoder_hidden_states_value_proj, value = value.split(
                 [num_encoder_hidden_states_tokens, num_query_tokens], dim=2
             )
-            key, value = get_runtime_state().cache_manager.update_and_get_kv_cache(
+            key, value = get_cache_manager().update_and_get_kv_cache(
                 new_kv=[key, value],
                 layer=attn,
                 slice_dim=2,
@@ -815,7 +815,7 @@ def __call__(

         #! ---------------------------------------- KV CACHE ----------------------------------------
         if not self.use_long_ctx_attn_kvcache:
-            key, value = get_runtime_state().cache_manager.update_and_get_kv_cache(
+            key, value = get_cache_manager().update_and_get_kv_cache(
                 new_kv=[key, value],
                 layer=attn,
                 slice_dim=2,
@@ -983,7 +983,7 @@ def __call__(

         #! ---------------------------------------- KV CACHE ----------------------------------------
         if not self.use_long_ctx_attn_kvcache:
-            key, value = get_runtime_state().cache_manager.update_and_get_kv_cache(
+            key, value = get_cache_manager().update_and_get_kv_cache(
                 new_kv=[key, value],
                 layer=attn,
                 slice_dim=2,
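
An aside on the slice_dim values visible above (an inference, not something the diff states): slice_dim appears to name the sequence dimension along which cached K/V is concatenated, and it differs with tensor layout; the ring/Ulysses paths pass slice_dim=1 (suggesting [batch, seq, heads, head_dim]) while most processors here pass slice_dim=2 (suggesting [batch, heads, seq, head_dim]). A tiny illustration of that concatenation under those assumed layouts:

import torch

# Assumed layout for the slice_dim=1 call sites: [batch, seq, heads, head_dim].
cached_k = torch.zeros(1, 16, 8, 64)     # keys cached from earlier steps/patches
new_k = torch.zeros(1, 4, 8, 64)         # keys for newly arrived tokens
k = torch.cat([cached_k, new_k], dim=1)  # append along the sequence axis
print(k.shape)  # torch.Size([1, 20, 8, 64])

# Assumed layout for the slice_dim=2 call sites: [batch, heads, seq, head_dim].
cached_k = torch.zeros(1, 8, 16, 64)
new_k = torch.zeros(1, 8, 4, 64)
print(torch.cat([cached_k, new_k], dim=2).shape)  # torch.Size([1, 8, 20, 64])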
6 changes: 3 additions & 3 deletions xfuser/model_executor/models/base_model.py

@@ -5,7 +5,7 @@
 import torch.nn as nn
 from xfuser.config import InputConfig, ParallelConfig, RuntimeConfig
 from xfuser.core.distributed.parallel_state import get_sequence_parallel_world_size
-from xfuser.core.distributed.runtime_state import get_runtime_state
+from xfuser.core.distributed.runtime_state import get_cache_manager, get_runtime_state
 from xfuser.model_executor.base_wrapper import xFuserBaseWrapper
 from xfuser.model_executor.layers import *
 from xfuser.core.distributed import get_world_group
@@ -127,11 +127,11 @@ def _register_cache(
             if get_sequence_parallel_world_size() == 1 or not getattr(
                 layer.processor, "use_long_ctx_attn_kvcache", False
             ):
-                get_runtime_state().cache_manager.register_cache_entry(
+                get_cache_manager().register_cache_entry(
                     layer, layer_type="attn", cache_type="naive_cache"
                 )
             else:
-                get_runtime_state().cache_manager.register_cache_entry(
+                get_cache_manager().register_cache_entry(
                     layer,
                     layer_type="attn",
                     cache_type="sequence_parallel_attn_cache",
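
Taken together, the call sites in this commit imply two methods on the cache manager: register_cache_entry(layer, layer_type, cache_type) and update_and_get_kv_cache(new_kv, layer, slice_dim, ...). The sketch below is a hypothetical reconstruction from those call sites only; xfuser's actual CacheManager, including how "naive_cache" differs from "sequence_parallel_attn_cache" and how stale entries are handled, is not shown in this diff.

from typing import Dict, List, Optional

import torch


class SketchCacheManager:
    """Hypothetical interface inferred from this commit's call sites."""

    def __init__(self) -> None:
        # One cache slot per registered layer, keyed by object identity.
        self._kv: Dict[int, Optional[List[torch.Tensor]]] = {}
        self._kind: Dict[int, str] = {}

    def register_cache_entry(self, layer, layer_type: str, cache_type: str) -> None:
        self._kv[id(layer)] = None
        self._kind[id(layer)] = cache_type  # "naive_cache" or "sequence_parallel_attn_cache"

    def update_and_get_kv_cache(
        self, new_kv: List[torch.Tensor], layer, slice_dim: int
    ) -> List[torch.Tensor]:
        # Toy policy: append along slice_dim and hand the full cache back.
        cached = self._kv.get(id(layer))
        if cached is None:
            cached = [t.clone() for t in new_kv]
        else:
            cached = [
                torch.cat([old, new], dim=slice_dim)
                for old, new in zip(cached, new_kv)
            ]
        self._kv[id(layer)] = cached
        return cached


if __name__ == "__main__":
    mgr = SketchCacheManager()
    layer = object()
    mgr.register_cache_entry(layer, layer_type="attn", cache_type="naive_cache")
    k, v = mgr.update_and_get_kv_cache(
        new_kv=[torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)], layer=layer, slice_dim=1
    )
    print(k.shape, v.shape)  # torch.Size([1, 2, 4, 8]) for both on the first update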
2 changes: 2 additions & 0 deletions

@@ -24,6 +24,8 @@
 logger = init_logger(__name__)


+# adapted from
+# https://github.com/huggingface/diffusers/blob/b5f591fea843cb4bf1932bd94d1db5d5eebe3298/src/diffusers/models/transformers/hunyuan_transformer_2d.py#L203
 @xFuserTransformerWrappersRegister.register(HunyuanDiT2DModel)
 class xFuserHunyuanDiT2DWrapper(xFuserTransformerBaseWrapper):
     def __init__(
