diff --git a/clip/clip.py b/clip/clip.py
index 55e1433f8..974ef060c 100644
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -180,7 +180,7 @@ def patch_float(module):
     return model, _transform(model.input_resolution.item())
 
 
-def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
     """
     Returns the tokenized representation of given input string(s)
 
@@ -192,6 +192,9 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo
     context_length : int
         The context length to use; all CLIP models use 77 as the context length
 
+    truncate: bool
+        Whether to truncate the text in case its encoding is longer than the context length
+
     Returns
     -------
     A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
@@ -206,7 +209,11 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.Lo
 
     for i, tokens in enumerate(all_tokens):
         if len(tokens) > context_length:
-            raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+            if truncate:
+                tokens = tokens[:context_length]
+                tokens[-1] = eot_token
+            else:
+                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
         result[i, :len(tokens)] = torch.tensor(tokens)
 
     return result
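
For context, a minimal usage sketch of the new flag, assuming the package is imported as `clip` per the repo README; the over-length caption below is a made-up example:

```python
import clip

# A made-up caption long enough to exceed the 77-token context window.
long_text = "a photo of " + "an extremely detailed scene " * 40

# Default behaviour is unchanged: over-length inputs still raise RuntimeError.
try:
    clip.tokenize(long_text)
except RuntimeError as err:
    print(err)

# With truncate=True the encoding is cut to context_length tokens and the
# final position is overwritten with the end-of-text token.
tokens = clip.tokenize(long_text, truncate=True)
print(tokens.shape)  # torch.Size([1, 77])
```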