[Bugs] Fix HFCheckpointHook bugs when training deepseekv2 and mixtral without shard moe (#774)

fix HFCheckpointHook bugs when training deepseekv2 and mixtral without shard moe
HIT-cwh authored Jun 17, 2024
1 parent c2328a0 commit bddf85d
Showing 1 changed file with 8 additions and 3 deletions.
xtuner/engine/hooks/hf_checkpoint_hook.py (8 additions, 3 deletions)
@@ -11,7 +11,7 @@
 from mmengine.runner import FlexibleRunner

 from xtuner.registry import BUILDER
-from xtuner.utils import SUPPORT_MODELS, get_origin_state_dict
+from xtuner.utils import get_origin_state_dict

 DATA_BATCH = Optional[Union[dict, tuple, list]]

@@ -23,6 +23,12 @@ class HFCheckpointHook(Hook):
     def __init__(self, out_dir: Optional[Union[str, Path]] = None) -> None:
         self.out_dir = out_dir

+    @staticmethod
+    def _use_shard_moe(llm):
+        config = llm.config
+        moe_implementation = getattr(config, 'moe_implementation', 'origin')
+        return moe_implementation == 'shard'
+
     def after_run(self, runner) -> None:
         assert isinstance(runner,
                           FlexibleRunner), 'Runner should be `FlexibleRunner`'
@@ -55,8 +61,7 @@ def after_run(self, runner) -> None:
             val = state_dict.pop(k)
             state_dict[k[4:]] = val

-        model_name = type(llm).__name__
-        if model_name in SUPPORT_MODELS:
+        if self._use_shard_moe(llm):
            print_log('recover the origin state_dict from merged one ...')
            state_dict = get_origin_state_dict(state_dict, llm)

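For context, here is a minimal, runnable sketch of the gating logic this commit introduces. It is not part of the commit itself, and DummyConfig, DummyLLM, and use_shard_moe are hypothetical stand-ins for the HuggingFace-style model, its config, and the new HFCheckpointHook._use_shard_moe helper:

# Hypothetical stand-ins; the real hook receives the wrapped LLM from the runner.
class DummyConfig:
    def __init__(self, moe_implementation=None):
        if moe_implementation is not None:
            self.moe_implementation = moe_implementation


class DummyLLM:
    def __init__(self, config):
        self.config = config


def use_shard_moe(llm):
    # Mirrors the new HFCheckpointHook._use_shard_moe: fall back to 'origin'
    # when the config does not define `moe_implementation`.
    return getattr(llm.config, 'moe_implementation', 'origin') == 'shard'


# A config without shard MoE no longer triggers the state-dict recovery path;
# a config with moe_implementation='shard' still does.
assert not use_shard_moe(DummyLLM(DummyConfig()))
assert use_shard_moe(DummyLLM(DummyConfig('shard')))

The design change is that the hook now keys the recovery step off the config's `moe_implementation` attribute instead of the model class name, so DeepseekV2 and Mixtral models trained without shard MoE skip the `get_origin_state_dict` conversion.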
