Skip to content

Commit

Permalink
Update chat templates to use the new API (huggingface#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 authored Mar 15, 2024
1 parent bb7f728 commit 24a2227
Showing 1 changed file with 18 additions and 93 deletions.
111 changes: 18 additions & 93 deletions src/transformers/models/cohere/tokenization_cohere_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
# This file is based on the tokenization_llama_fast.py file in transformers

import pickle
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, List, Literal, Union

from tokenizers import processors

from ...pipelines.conversational import Conversation
from ...tokenization_utils_base import BatchEncoding, TensorType
from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from ...utils.versions import require_version
Expand Down Expand Up @@ -253,7 +253,7 @@ def default_chat_template(self):
"your model, please set `tokenizer.chat_template` to an appropriate template. "
"See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
)
template = (
default_template = (
"{{ bos_token }}"
"{% if messages[0]['role'] == 'system' %}"
"{% set loop_messages = messages[1:] %}" # Extract system message if it's present
Expand Down Expand Up @@ -283,15 +283,13 @@ def default_chat_template(self):
"{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}"
"{% endif %}"
)
template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
default_template = default_template.replace(
"USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false"
)
default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)

return template
default_template = default_template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)

@property
def default_tool_use_template(self):
template = (
tool_use_template = (
"{{ bos_token }}"
"{% if messages[0]['role'] == 'system' %}"
"{% set loop_messages = messages[1:] %}" # Extract system message if it's present
Expand Down Expand Up @@ -360,13 +358,10 @@ def default_tool_use_template(self):
"{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}"
"{% endif %}"
)
default_message = DEFAULT_RAG_PREAMBLE.replace("\n", "\\n").replace("'", "\\'")
template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
return template
default_tool_message = DEFAULT_RAG_PREAMBLE.replace("\n", "\\n").replace("'", "\\'")
tool_use_template = tool_use_template.replace("DEFAULT_SYSTEM_MESSAGE", default_tool_message)

@property
def default_grounded_generation_template(self):
template = (
rag_template = (
"{{ bos_token }}"
"{% if messages[0]['role'] == 'system' %}"
"{% set loop_messages = messages[1:] %}" # Extract system message if it's present
Expand Down Expand Up @@ -417,66 +412,15 @@ def default_grounded_generation_template(self):
"{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}"
"{% endif %}"
)
default_message = DEFAULT_RAG_PREAMBLE.replace("\n", "\\n").replace("'", "\\'")
template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
return template

def _apply_template_with_arguments(
self,
conversation: Union[List[Dict[str, str]], "Conversation"],
template: Optional[str] = None,
add_generation_prompt: bool = False,
tokenize: bool = True,
padding: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_dict: bool = False,
**kwargs,
) -> Union[str, List[int]]:
"""Just tokenization_utils_base.apply_chat_template, but modified so that the Jinja template can take kwargs"""
if hasattr(conversation, "messages"):
# Indicates it's a Conversation object
conversation = conversation.messages

# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(template)

rendered = compiled_template.render(
messages=conversation, add_generation_prompt=add_generation_prompt, **kwargs, **self.special_tokens_map
)
default_rag_message = DEFAULT_RAG_PREAMBLE.replace("\n", "\\n").replace("'", "\\'")
rag_template = rag_template.replace("DEFAULT_SYSTEM_MESSAGE", default_rag_message)

if padding is True:
padding = "max_length" # There's only one sequence here, so "longest" makes no sense
if tokenize:
if return_dict:
return self(
rendered,
padding=padding,
truncation=truncation,
max_length=max_length,
add_special_tokens=False,
return_tensors=return_tensors,
**kwargs,
)
else:
return self.encode(
rendered,
padding=padding,
truncation=truncation,
max_length=max_length,
add_special_tokens=False,
return_tensors=return_tensors,
**kwargs,
)
else:
return rendered
return {"default": default_template, "tool_use": tool_use_template, "rag": rag_template}

def apply_tool_use_template(
self,
conversation: Union[List[Dict[str, str]], "Conversation"],
tools: List[Dict],
tool_use_template: Optional[str] = None,
**kwargs,
) -> Union[str, List[int]]:
"""Create a Command-R tool-use prompt.
Expand Down Expand Up @@ -508,8 +452,6 @@ def apply_tool_use_template(
* description (str): The description of the parameter.
* type (str): the type of the parameter - most effective for python builtin data types, such as 'str', 'bool'
* required (bool): Denotes whether the parameter is always present (required) or not. Defaults to not required.
tool_use_template (str, *optional*): A Jinja template to use for this conversion. If
this is not passed, the tokenizer's `tool_use_template` (or, failing that, its default tool-use template) will be used instead.
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
the start of an assistant message. This is useful when you want to generate a response from the model.
Note that this argument will be passed to the chat template, and so it must be supported in the
Expand Down Expand Up @@ -623,17 +565,10 @@ def directly_answer() -> List[Dict]:
]
```
"""
# priority: `tool_use_template` argument > `tokenizer.tool_use_template` > `tokenizer.default_tool_use_template`
if tool_use_template is None:
if self.tool_use_template is not None:
tool_use_template = self.tool_use_template
else:
tool_use_template = self.default_tool_use_template

return self._apply_template_with_arguments(
return self.apply_chat_template(
conversation,
chat_template="tool_use",
tools=tools,
template=tool_use_template,
**kwargs,
)

Expand All @@ -642,7 +577,6 @@ def apply_grounded_generation_template(
conversation: Union[List[Dict[str, str]], "Conversation"],
documents: List[Dict],
citation_mode: Literal["fast", "accurate"] = "accurate",
grounded_generation_template: Optional[str] = None,
**kwargs,
) -> Union[str, List[int]]:
"""Create a Command-R grounded generation (aka RAG) prompt.
Expand All @@ -666,8 +600,6 @@ def apply_grounded_generation_template(
citation_mode: either "accurate" (prompt the model to generate an answer first, then rewrite it with citation
spans in) or "fast", where the prompt instructs the model to generate an answer with citations in directly.
The former has higher quality citations, the latter requires fewer tokens to be generated.
grounded_generation_template (str, *optional*): A Jinja template to use for this conversion. If
this is not passed, the model's default grounded generation template will be used instead.
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
the start of an assistant message. This is useful when you want to generate a response from the model.
Note that this argument will be passed to the chat template, and so it must be supported in the
Expand Down Expand Up @@ -748,17 +680,10 @@ def apply_grounded_generation_template(
Answer: The Emperor Penguin is the tallest or biggest penguin in the world. It is a bird that lives only in Antarctica and grows to a height of around 122 centimetres.
Grounded answer: The <co: 0>Emperor Penguin</co: 0> is the <co: 0>tallest</co: 0> or biggest penguin in the world. It is a bird that <co: 1>lives only in Antarctica</co: 1> and <co: 0>grows to a height of around 122 centimetres.</co: 0>
"""
# priority: `grounded_generation_template` argument > `tokenizer.grounded_generation_template` > `tokenizer.default_grounded_generation_template`
if grounded_generation_template is None:
if self.grounded_generation_template is not None:
grounded_generation_template = self.grounded_generation_template
else:
grounded_generation_template = self.default_grounded_generation_template

return self._apply_template_with_arguments(
return self.apply_chat_template(
conversation,
chat_template="rag",
documents=documents,
template=grounded_generation_template,
citation_mode=citation_mode,
**kwargs,
)
Expand Down

0 comments on commit 24a2227

Please sign in to comment.