Skip to content

Commit

Permalink
[Bugs] fix dispatch bugs when model not in LOWEST_TRANSFORMERS_VERSION (
Browse files Browse the repository at this point in the history
#802)

* fix dispatch bugs when model not in LOWEST_TRANSFORMERS_VERSION

* move rope_theta
  • Loading branch information
HIT-cwh authored Jul 9, 2024
1 parent 44749c2 commit 9c28b40
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions xtuner/model/modules/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,14 @@ def replace_rote(model):
from mmengine import print_log
print_log = log_once(print_log)

assert hasattr(model.config, 'rope_theta'), \
'`rope_theta` should be in the model config.'
rope_theta = model.config.rope_theta

def traverse(module):
for name, child in module.named_children():
cls_name = type(child).__name__
if cls_name in ROTE_DISPATCH_MAPPING:
assert hasattr(model.config, 'rope_theta'), \
'`rope_theta` should be in the model config.'
rope_theta = model.config.rope_theta

rote = ROTE_DISPATCH_MAPPING[cls_name]
rote = rote.build()
print_log(f'replace {cls_name}', 'current')
Expand All @@ -258,10 +258,11 @@ def check(model_name):
# a walkaround for reward model
model_name = model_name[:-5] + 'ForCausalLM'
msg = '{} requires transformers version at least {}, but got {}'
assert TRANSFORMERS_VERSION >= LOWEST_TRANSFORMERS_VERSION[
model_name], msg.format(model_name,
LOWEST_TRANSFORMERS_VERSION[model_name],
TRANSFORMERS_VERSION)
if model_name in LOWEST_TRANSFORMERS_VERSION:
assert TRANSFORMERS_VERSION >= LOWEST_TRANSFORMERS_VERSION[
model_name], msg.format(
model_name, LOWEST_TRANSFORMERS_VERSION[model_name],
TRANSFORMERS_VERSION)

check(type(model).__name__)
if use_varlen_attn:
Expand Down

0 comments on commit 9c28b40

Please sign in to comment.