
Commit cf73f3d
update
Jintao-Huang committed Oct 31, 2024
1 parent c048885 commit cf73f3d
Showing 4 changed files with 18 additions and 8 deletions.
swift/llm/dataset/preprocess/core.py (1 change: 0 additions & 1 deletion)
@@ -4,7 +4,6 @@
 from functools import partial
 from typing import Any, Callable, Dict, List, Literal, Optional, Set, Union

-import numpy as np
 from datasets import Dataset as HfDataset
 from datasets import IterableDataset as HfIterableDataset
swift/llm/model/model_meta.py (13 changes: 12 additions & 1 deletion)
@@ -1,5 +1,15 @@
 import itertools
 from dataclasses import dataclass, field
-from typing import Optional, List
+from typing import Callable, List, Optional, Tuple, TypeVar, Union
+
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers.utils.versions import require_version
+
+from swift.utils import get_logger
+
+logger = get_logger()
+GetModelTokenizerFunction = Callable[..., Tuple[Optional[PreTrainedModel], PreTrainedTokenizerBase]]
+
+
 @dataclass
 class Model:
@@ -52,6 +62,7 @@ class ModelMeta:
     support_megatron: bool = False

     def get_matched_model_groups(self, model_dir: str) -> List[ModelGroup]:
+        from .utils import HfConfigFactory
         model_name = HfConfigFactory._get_model_name(model_dir).lower()
         res = []
         seen = set()
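
Note on the moved alias: GetModelTokenizerFunction now lives in model_meta.py and types each model's loader callback as (*args, **kwargs) -> (model or None, tokenizer). Below is a minimal sketch of a conforming loader; the Auto* classes and parameter names are illustrative assumptions, not part of this commit.

# Sketch only: a loader matching the GetModelTokenizerFunction signature.
# AutoModelForCausalLM/AutoTokenizer are stand-ins chosen for illustration.
from typing import Optional, Tuple

from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel,
                          PreTrainedTokenizerBase)


def load_example(model_dir: str, load_model: bool = True,
                 **kwargs) -> Tuple[Optional[PreTrainedModel], PreTrainedTokenizerBase]:
    # The model slot is Optional: callers sometimes need only the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForCausalLM.from_pretrained(model_dir) if load_model else None
    return model, tokenizer
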
swift/llm/model/register.py (9 changes: 3 additions & 6 deletions)
@@ -14,16 +14,15 @@
                           PreTrainedModel, PreTrainedTokenizerBase)
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_cuda_available, is_torch_npu_available
-from transformers.utils.versions import require_version

 from swift.utils import get_dist_setting, get_logger, is_ddp_plus_mp, is_dist, is_unsloth_available, use_torchacc
+from .model_meta import ModelMeta
 from .utils import AttnImpl, HfConfigFactory, safe_snapshot_download

 MODEL_MAPPING: Dict[str, Dict[str, Any]] = {}

 ARCH_MAPPING: Optional[Dict[str, Dict[str, List[str]]]] = None

-GetModelTokenizerFunction = Callable[..., Tuple[Optional[PreTrainedModel], PreTrainedTokenizerBase]]
 logger = get_logger()

@@ -237,13 +236,11 @@ def get_model_tokenizer(
     if download_model is None:
         download_model = load_model
     # download config.json
-    model_dir = safe_snapshot_download(
-        model_id_or_path, revision=revision, download_model=False, use_hf=use_hf)
+    model_dir = safe_snapshot_download(model_id_or_path, revision=revision, download_model=False, use_hf=use_hf)
     model_info = HfConfigFactory.get_model_info(model_dir)

     if download_model:
-        safe_snapshot_download(
-            model_id_or_path, revision=revision, download_model=download_model, use_hf=use_hf)
+        safe_snapshot_download(model_id_or_path, revision=revision, download_model=download_model, use_hf=use_hf)

     if not use_torchacc() and device_map is None:
         device_map = get_default_device_map()
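
The re-joined calls preserve the existing config-first download: the first safe_snapshot_download fetches only config files (download_model=False) so HfConfigFactory.get_model_info can run cheaply, and the weights are pulled afterwards only when download_model is true. A rough sketch of the same pattern using huggingface_hub directly; treating it as a stand-in for safe_snapshot_download is my assumption.

# Sketch of the config-first pattern with plain huggingface_hub; swift's
# safe_snapshot_download additionally handles hub selection and distributed-safe locking.
from huggingface_hub import hf_hub_download, snapshot_download

repo_id = 'Qwen/Qwen2-7B-Instruct'  # hypothetical model id
config_path = hf_hub_download(repo_id, 'config.json')  # cheap: one small file
# ... parse the config into model_info here ...
model_dir = snapshot_download(repo_id)  # heavy: full weights, only when needed
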
swift/llm/model/utils.py (3 changes: 3 additions & 0 deletions)
@@ -15,9 +15,11 @@

 from swift.hub import HFHub, MSHub, default_hub
 from swift.utils import deep_getattr, get_logger, is_dist, is_dist_ta, safe_ddp_context
+from .model_meta import ModelMeta

 logger = get_logger()


 @dataclass
 class ModelInfo:
+    model_meta: ModelMeta
@@ -28,6 +30,7 @@ class ModelInfo:
     quant_bits: int
     config: PretrainedConfig

+
 class HfConfigFactory:
     """This class is used to read config from config.json(maybe params.json also)"""

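
Since utils.py now imports ModelMeta at module level, the reverse import in model_meta.py is deferred into get_matched_model_groups (see that hunk above) to avoid a circular import. The general shape of the pattern, with illustrative module names a.py and b.py:

# a.py (plays the role of utils.py): importing b at the top is safe.
from b import helper
something = 42

# b.py (plays the role of model_meta.py): a top-level "from a import ..."
# would close the cycle, so the import is deferred into the function body.
def helper():
    from a import something  # resolved when helper() runs, not at import time
    return something
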
