From 24a22278e163ab838f5cd0597a812e6b1fa5a8a7 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 15 Mar 2024 12:20:37 +0000 Subject: [PATCH] Update chat templates to use the new API (#15) --- .../models/cohere/tokenization_cohere_fast.py | 111 +++--------------- 1 file changed, 18 insertions(+), 93 deletions(-) diff --git a/src/transformers/models/cohere/tokenization_cohere_fast.py b/src/transformers/models/cohere/tokenization_cohere_fast.py index 66efbddab6a73a..e733a6dfd09541 100644 --- a/src/transformers/models/cohere/tokenization_cohere_fast.py +++ b/src/transformers/models/cohere/tokenization_cohere_fast.py @@ -16,12 +16,12 @@ # This file is based on the tokenization_llama_fast.py file in transformers import pickle -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Literal, Union from tokenizers import processors from ...pipelines.conversational import Conversation -from ...tokenization_utils_base import BatchEncoding, TensorType +from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging from ...utils.versions import require_version @@ -253,7 +253,7 @@ def default_chat_template(self): "your model, please set `tokenizer.chat_template` to an appropriate template. " "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n" ) - template = ( + default_template = ( "{{ bos_token }}" "{% if messages[0]['role'] == 'system' %}" "{% set loop_messages = messages[1:] %}" # Extract system message if it's present @@ -283,15 +283,13 @@ def default_chat_template(self): "{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}" "{% endif %}" ) - template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") + default_template = default_template.replace( + "USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false" + ) default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") - template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) - - return template + default_template = default_template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) - @property - def default_tool_use_template(self): - template = ( + tool_use_template = ( "{{ bos_token }}" "{% if messages[0]['role'] == 'system' %}" "{% set loop_messages = messages[1:] %}" # Extract system message if it's present @@ -360,13 +358,10 @@ def default_tool_use_template(self): "{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}" "{% endif %}" ) - default_message = DEFAULT_RAG_PREAMBLE.replace("\n", "\\n").replace("'", "\\'") - template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) - return template + default_tool_message = DEFAULT_RAG_PREAMBLE.replace("\n", "\\n").replace("'", "\\'") + tool_use_template = tool_use_template.replace("DEFAULT_SYSTEM_MESSAGE", default_tool_message) - @property - def default_grounded_generation_template(self): - template = ( + rag_template = ( "{{ bos_token }}" "{% if messages[0]['role'] == 'system' %}" "{% set loop_messages = messages[1:] %}" # Extract system message if it's present @@ -417,66 +412,15 @@ def default_grounded_generation_template(self): "{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}" "{% endif %}" ) - default_message = DEFAULT_RAG_PREAMBLE.replace("\n", "\\n").replace("'", "\\'") - template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) - return template - - def _apply_template_with_arguments( - self, - conversation: Union[List[Dict[str, str]], "Conversation"], - template: Optional[str] = None, - add_generation_prompt: bool = False, - tokenize: bool = True, - padding: bool = False, - truncation: bool = False, - max_length: Optional[int] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_dict: bool = False, - **kwargs, - ) -> Union[str, List[int]]: - """Just tokenization_utils_base.apply_chat_template, but modified so that the jinjia template can take kwargs""" - if hasattr(conversation, "messages"): - # Indicates it's a Conversation object - conversation = conversation.messages - - # Compilation function uses a cache to avoid recompiling the same template - compiled_template = self._compile_jinja_template(template) - - rendered = compiled_template.render( - messages=conversation, add_generation_prompt=add_generation_prompt, **kwargs, **self.special_tokens_map - ) + default_rag_message = DEFAULT_RAG_PREAMBLE.replace("\n", "\\n").replace("'", "\\'") + rag_template = rag_template.replace("DEFAULT_SYSTEM_MESSAGE", default_rag_message) - if padding is True: - padding = "max_length" # There's only one sequence here, so "longest" makes no sense - if tokenize: - if return_dict: - return self( - rendered, - padding=padding, - truncation=truncation, - max_length=max_length, - add_special_tokens=False, - return_tensors=return_tensors, - **kwargs, - ) - else: - return self.encode( - rendered, - padding=padding, - truncation=truncation, - max_length=max_length, - add_special_tokens=False, - return_tensors=return_tensors, - **kwargs, - ) - else: - return rendered + return {"default": default_template, "tool_use": tool_use_template, "rag": rag_template} def apply_tool_use_template( self, conversation: Union[List[Dict[str, str]], "Conversation"], tools: List[Dict], - tool_use_template: Optional[str] = None, **kwargs, ) -> Union[str, List[int]]: """Create a Command-R tool-use prompt. @@ -508,8 +452,6 @@ def apply_tool_use_template( * description (str): The description of the parameter. * type (str): the type of the parameter - most effective for python builtin data types, such as 'str', 'bool' * required: boolean: Denotes whether the parameter is always present (required) or not. Defaults to not required. - tool_use_template (str, *optional*): A Jinja template to use for this conversion. If - this is not passed, the model's default chat template will be used instead. add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate the start of an assistant message. This is useful when you want to generate a response from the model. Note that this argument will be passed to the chat template, and so it must be supported in the @@ -623,17 +565,10 @@ def directly_answer() -> List[Dict]: ] ``` """ - # priority: `tool_use_template` argument > `tokenizer.tool_use_template` > `tokenizer.default_tool_use_template` - if tool_use_template is None: - if self.tool_use_template is not None: - tool_use_template = self.tool_use_template - else: - tool_use_template = self.default_tool_use_template - - return self._apply_template_with_arguments( + return self.apply_chat_template( conversation, + chat_template="tool_use", tools=tools, - template=tool_use_template, **kwargs, ) @@ -642,7 +577,6 @@ def apply_grounded_generation_template( conversation: Union[List[Dict[str, str]], "Conversation"], documents: List[Dict], citation_mode: Literal["fast", "accurate"] = "accurate", - grounded_generation_template: Optional[str] = None, **kwargs, ) -> Union[str, List[int]]: """Create a Command-R grounded generation (aka RAG) prompt. @@ -666,8 +600,6 @@ def apply_grounded_generation_template( citation_mode: either "accurate" (prompt the model to generate an answer first, then rewrite it with citation spans in) or "fast", where the prompt instructs the model to generate an answer with citations in directly. The former has higher quality citations, the latter requires fewer tokens to be generated. - grounded_generation_template (str, *optional*): A Jinja template to use for this conversion. If - this is not passed, the model's default grounded_generation_template template will be used instead. add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate the start of an assistant message. This is useful when you want to generate a response from the model. Note that this argument will be passed to the chat template, and so it must be supported in the @@ -748,17 +680,10 @@ def apply_grounded_generation_template( Answer: The Emperor Penguin is the tallest or biggest penguin in the world. It is a bird that lives only in Antarctica and grows to a height of around 122 centimetres. Grounded answer: The Emperor Penguin is the tallest or biggest penguin in the world. It is a bird that lives only in Antarctica and grows to a height of around 122 centimetres. """ - # priority: `grounded_generation_template` argument > `tokenizer.grounded_generation_template` > `tokenizer.default_grounded_generation_template` - if grounded_generation_template is None: - if self.grounded_generation_template is not None: - grounded_generation_template = self.grounded_generation_template - else: - grounded_generation_template = self.default_grounded_generation_template - - return self._apply_template_with_arguments( + return self.apply_chat_template( conversation, + chat_template="rag", documents=documents, - template=grounded_generation_template, citation_mode=citation_mode, **kwargs, )