diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index f39ed96009..7796e90f97 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -17,9 +17,7 @@ def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: self.model = dispatch_model(self.model) self.template = get_template(data_args.template) self.source_prefix = data_args.source_prefix - self.stop_ids = [ - self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words - ] + self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words) self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words)) self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index ee33218c9b..eed7892bb2 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -6,13 +6,14 @@ AutoConfig, AutoModelForCausalLM, AutoTokenizer, - BitsAndBytesConfig + BitsAndBytesConfig, + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizerBase ) from transformers.utils import check_min_version from transformers.utils.versions import require_version from transformers.deepspeed import is_deepspeed_zero3_enabled -from transformers.modeling_utils import PretrainedConfig, PreTrainedModel -from transformers.tokenization_utils import PreTrainedTokenizerBase from trl import AutoModelForCausalLMWithValueHead from llmtuner.extras.logging import reset_logging, get_logger @@ -22,6 +23,7 @@ from llmtuner.tuner.core.adapter import init_adapter if TYPE_CHECKING: + from transformers import PreTrainedTokenizer from llmtuner.hparams import ModelArguments @@ -40,7 +42,7 @@ def load_model_and_tokenizer( finetuning_args: "FinetuningArguments", is_trainable: Optional[bool] = False, stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft" -) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: +) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]: r""" Loads pretrained model and tokenizer.