-
Notifications
You must be signed in to change notification settings - Fork 546
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Auto-apply chat template in `SequenceGenerator` and `SequenceGeneratorAdapter`, if available
#1019
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import datetime | ||
import warnings | ||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING, Iterator, List, Optional, Union | ||
|
||
|
@@ -20,13 +21,15 @@ def __init__( | |
model, | ||
sampler, | ||
device, | ||
apply_chat_template: bool = True, | ||
): | ||
self.fsm = fsm | ||
self.model = model | ||
self.sampler = sampler | ||
self.tokenizer = model.tokenizer | ||
self.device = device | ||
self.num_samples = sampler.samples | ||
self.apply_chat_template = apply_chat_template | ||
|
||
def get_generated_token_ids( | ||
self, | ||
|
@@ -132,6 +135,7 @@ def __call__( | |
max_tokens: Optional[int] = None, | ||
stop_at: Optional[Union[str, List[str]]] = None, | ||
rng: Optional["torch.Generator"] = None, | ||
apply_chat_template: Optional[bool] = None, | ||
) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]: | ||
"""Generate the full text sequence. | ||
|
||
|
@@ -153,16 +157,25 @@ def __call__( | |
rng | ||
The random number generator. Defaults to a non-seeded `torch.Generator` | ||
instance. | ||
apply_chat_template | ||
Whether to apply the chat template to the prompts. Defaults to the value | ||
set at init. Only applies to `TransformerTokenizer` for now. | ||
|
||
Returns | ||
------- | ||
The generation(s), potentially cast to another type. | ||
""" | ||
if apply_chat_template is None: | ||
apply_chat_template = self.apply_chat_template | ||
|
||
import torch | ||
|
||
if isinstance(prompts, str): | ||
prompts = [prompts] | ||
|
||
if apply_chat_template: | ||
apply_chat_template_util(self.model, prompts) | ||
|
||
if isinstance(stop_at, str): | ||
stop_at = [stop_at] | ||
|
||
|
@@ -250,6 +263,7 @@ def stream( | |
max_tokens: Optional[int] = None, | ||
stop_at: Optional[Union[str, List[str]]] = None, | ||
rng: Optional["torch.Generator"] = None, | ||
apply_chat_template: Optional[bool] = None, | ||
) -> Iterator[Union[List[str], str, List[List[str]]]]: | ||
"""Generate the text sequence one token at a time. | ||
|
||
|
@@ -270,17 +284,26 @@ def stream( | |
rng | ||
The random number generator. Defaults to a non-seeded `torch.Generator` | ||
instance. | ||
apply_chat_template | ||
Whether to apply the chat template to the prompts. Defaults to the value | ||
set at init. Only applies to `TransformerTokenizer` for now. | ||
|
||
Returns | ||
------- | ||
A string or list of strings that contain the generated text. | ||
|
||
""" | ||
if apply_chat_template is None: | ||
apply_chat_template = self.apply_chat_template | ||
|
||
import torch | ||
|
||
if isinstance(prompts, str): | ||
prompts = [prompts] | ||
|
||
if apply_chat_template: | ||
apply_chat_template_util(self.model, prompts) | ||
|
||
if isinstance(stop_at, str): | ||
stop_at = [stop_at] | ||
|
||
|
@@ -423,7 +446,9 @@ class SequenceGeneratorAdapter: | |
|
||
""" | ||
|
||
def __init__(self, model, logits_processor, sampler): | ||
def __init__( | ||
self, model, logits_processor, sampler, apply_chat_template: bool = True | ||
): | ||
self.model = model | ||
self.logits_processor = logits_processor | ||
|
||
|
@@ -444,6 +469,8 @@ def __init__(self, model, logits_processor, sampler): | |
"beam_search", sampler.samples, None, None, 1.0 | ||
) | ||
|
||
self.apply_chat_template = apply_chat_template | ||
|
||
def prepare_generation_parameters( | ||
self, | ||
max_tokens: Optional[int], | ||
|
@@ -485,9 +512,15 @@ def __call__( | |
max_tokens: Optional[int] = None, | ||
stop_at: Optional[Union[str, List[str]]] = None, | ||
seed: Optional[int] = None, | ||
apply_chat_template: Optional[bool] = None, | ||
**model_specific_params, | ||
): | ||
"""Generate text from a prompt or list of prompts.""" | ||
if apply_chat_template is None: | ||
apply_chat_template = self.apply_chat_template | ||
|
||
if apply_chat_template: | ||
apply_chat_template_util(self.model, prompts) | ||
|
||
def format(sequences): | ||
"""Apply formatting to every string in a completion.""" | ||
|
@@ -516,9 +549,14 @@ def stream( | |
max_tokens: Optional[int] = None, | ||
stop_at: Optional[Union[str, List[str]]] = None, | ||
seed: Optional[int] = None, | ||
apply_chat_template: Optional[bool] = None, | ||
**model_specific_params, | ||
): | ||
"""Return a text generator from a prompt or a list of prompts.""" | ||
if apply_chat_template is None: | ||
apply_chat_template = self.apply_chat_template | ||
if apply_chat_template: | ||
apply_chat_template_util(self.model, prompts) | ||
generation_params = self.prepare_generation_parameters( | ||
max_tokens, stop_at, seed | ||
) | ||
|
@@ -529,3 +567,22 @@ def stream( | |
self.sampling_params, | ||
**model_specific_params, | ||
) | ||
|
||
|
||
def apply_chat_template_util(model, prompts: Union[str, List[str]]) -> List[str]: | ||
from outlines.models.transformers import TransformerTokenizer | ||
|
||
if isinstance(prompts, str): | ||
prompts = [prompts] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the signature should be |
||
if not isinstance(model.tokenizer, TransformerTokenizer): | ||
warnings.warn( | ||
"Chat template is only supported for `Transformer` models for now. The raw prompts will be used instead." | ||
) | ||
return prompts | ||
tokenizer: "TransformerTokenizer" = model.tokenizer | ||
if getattr(tokenizer.tokenizer, "chat_template", None) is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we use https://huggingface.co/docs/transformers/v4.34.0/chat_templating#how-do-chat-templates-work |
||
warnings.warn( | ||
"The model does not have chat template support. The raw prompts will be used instead. To turn this warning off, either explicitly set the `apply_chat_template` argument to 'False' or assign a value to `model.tokenizer.tokenizer.chat_template`." | ||
) | ||
return prompts | ||
return [tokenizer.apply_chat_template(prompt) for prompt in prompts] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To make this pythonic, we should have one obvious way of applying a chat template. IMO the argument should only be accepted in the constructor.