feat: Jamba instruct tokenizer (#84)
Josephasafg authored Mar 28, 2024
1 parent 1146741 commit 88ff9af
Showing 10 changed files with 556 additions and 49 deletions.
12 changes: 10 additions & 2 deletions ai21_tokenizer/__init__.py
@@ -1,8 +1,16 @@
from ai21_tokenizer.base_tokenizer import BaseTokenizer
from ai21_tokenizer.jamba_instruct_tokenizer import JambaInstructTokenizer
from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer
from ai21_tokenizer.tokenizer_factory import TokenizerFactory as Tokenizer
from ai21_tokenizer.tokenizer_factory import TokenizerFactory as Tokenizer, PreTrainedTokenizers
from .version import VERSION

__version__ = VERSION

__all__ = ["Tokenizer", "JurassicTokenizer", "BaseTokenizer", "__version__"]
__all__ = [
"Tokenizer",
"JurassicTokenizer",
"BaseTokenizer",
"__version__",
"PreTrainedTokenizers",
"JambaInstructTokenizer",
]
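
The expanded __all__ makes the new names importable directly from the package root. A minimal usage sketch of what that exposes (not part of the diff; it assumes the package is installed as ai21-tokenizer):

import ai21_tokenizer
from ai21_tokenizer import Tokenizer, PreTrainedTokenizers, JambaInstructTokenizer

# Tokenizer is the TokenizerFactory alias; PreTrainedTokenizers now lists the
# Jamba instruct tokenizer name alongside the J2 one.
print(ai21_tokenizer.__version__)
print(PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER)  # "jamba-instruct-tokenizer"
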
78 changes: 78 additions & 0 deletions ai21_tokenizer/jamba_instruct_tokenizer.py
@@ -0,0 +1,78 @@
from __future__ import annotations

import os
import tempfile
from pathlib import Path
from typing import Union, List, cast, Optional

from tokenizers import Tokenizer

from ai21_tokenizer import BaseTokenizer
from ai21_tokenizer.utils import PathLike

_TOKENIZER_FILE = "tokenizer.json"
_DEFAULT_MODEL_CACHE_DIR = Path(tempfile.gettempdir()) / "jamba_instruct"


class JambaInstructTokenizer(BaseTokenizer):
_tokenizer: Tokenizer

def __init__(
self,
model_path: str,
cache_dir: Optional[PathLike] = None,
):
"""
Args:
model_path: str
The identifier of a model on the Hugging Face Hub that contains a tokenizer.json file
cache_dir: Optional[PathLike]
The directory to cache the tokenizer.json file.
If not provided, the default cache directory will be used
"""
self._tokenizer = self._init_tokenizer(model_path=model_path, cache_dir=cache_dir or _DEFAULT_MODEL_CACHE_DIR)

def _init_tokenizer(self, model_path: PathLike, cache_dir: PathLike) -> Tokenizer:
if self._is_cached(cache_dir):
return self._load_from_cache(cache_dir / _TOKENIZER_FILE)

tokenizer = cast(
Tokenizer,
Tokenizer.from_pretrained(model_path),
)
self._cache_tokenizer(tokenizer, cache_dir)

return tokenizer

def _is_cached(self, cache_dir: PathLike) -> bool:
return Path(cache_dir).exists() and _TOKENIZER_FILE in os.listdir(cache_dir)

def _load_from_cache(self, cache_file: Path) -> Tokenizer:
return cast(Tokenizer, Tokenizer.from_file(str(cache_file)))

def _cache_tokenizer(self, tokenizer: Tokenizer, cache_dir: PathLike) -> None:
# create cache directory for caching the tokenizer and save it
Path(cache_dir).mkdir(parents=True, exist_ok=True)
tokenizer.save(str(cache_dir / _TOKENIZER_FILE))

def encode(self, text: str, **kwargs) -> List[int]:
return self._tokenizer.encode(text, **kwargs).ids

def decode(self, token_ids: List[int], **kwargs) -> str:
return self._tokenizer.decode(token_ids, **kwargs)

def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
if isinstance(tokens, str):
return self._tokenizer.token_to_id(tokens)

return [self._tokenizer.token_to_id(token) for token in tokens]

def convert_ids_to_tokens(self, token_ids: Union[int, List[int]], **kwargs) -> Union[str, List[str]]:
if isinstance(token_ids, int):
return self._tokenizer.id_to_token(token_ids)

return [self._tokenizer.id_to_token(token_id) for token_id in token_ids]

@property
def vocab_size(self) -> int:
return self._tokenizer.get_vocab_size()
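
A usage sketch of the new class (not part of the diff). It assumes network access to the Hugging Face Hub on the first run; afterwards tokenizer.json is loaded from the temp-dir cache set up above.

from ai21_tokenizer import JambaInstructTokenizer

# The first call downloads tokenizer.json from the Hub and caches it under
# <tempdir>/jamba_instruct; subsequent runs load the cached file instead.
tokenizer = JambaInstructTokenizer(model_path="ai21labs/Jamba-v0.1")

ids = tokenizer.encode("Hello, Jamba!")
print(ids)
print(tokenizer.convert_ids_to_tokens(ids))
print(tokenizer.decode(ids))
print(tokenizer.vocab_size)
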
25 changes: 22 additions & 3 deletions ai21_tokenizer/jurassic_tokenizer.py
@@ -2,12 +2,16 @@

import re
from dataclasses import dataclass
from pathlib import Path
from typing import List, Union, Optional, Dict, Any, Tuple, BinaryIO

import sentencepiece as spm

from ai21_tokenizer.base_tokenizer import BaseTokenizer
from ai21_tokenizer.utils import load_binary, is_number, PathLike
from ai21_tokenizer.utils import load_binary, is_number, PathLike, load_json

_MODEL_EXTENSION = ".model"
_MODEL_CONFIG_FILENAME = "config.json"


@dataclass
@@ -25,11 +29,11 @@ def __init__(
):
self._validate_init(model_path=model_path, model_file_handle=model_file_handle)

model_proto = load_binary(model_path) if model_path else model_file_handle.read()
model_proto = load_binary(self._get_model_file(model_path)) if model_path else model_file_handle.read()

# noinspection PyArgumentList
self._sp = spm.SentencePieceProcessor(model_proto=model_proto)
config = config or {}
config = self._get_config(model_path=model_path, config=config)

self.pad_id = config.get("pad_id")
self.unk_id = config.get("unk_id")
@@ -64,6 +68,21 @@ def _validate_init(self, model_path: Optional[PathLike], model_file_handle: Opti
if model_path is not None and model_file_handle is not None:
raise ValueError("Must provide exactly one of model_path or model_file_handle. Got both.")

def _get_model_file(self, model_path: PathLike) -> PathLike:
model_path = Path(model_path)

if model_path.is_dir():
return model_path / f"{model_path.name}{_MODEL_EXTENSION}"

return model_path

def _get_config(self, model_path: Optional[PathLike], config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
if model_path and Path(model_path).is_dir():
config_path = model_path / _MODEL_CONFIG_FILENAME
return load_json(config_path)

return config or {}

def _map_space_tokens(self) -> List[SpaceSymbol]:
res = []
for count in range(32, 0, -1):
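
The two new helpers let JurassicTokenizer be pointed at a directory instead of a .model file plus an explicit config dict: a directory resolves to <dir>/<dir-name>.model and its sibling config.json. A sketch under the assumption that the directory is the bundled j2-tokenizer resources folder from a local checkout:

from pathlib import Path
from ai21_tokenizer import JurassicTokenizer

# Hypothetical local path; any directory laid out as <name>/<name>.model
# plus config.json would resolve the same way.
model_dir = Path("ai21_tokenizer/resources/j2-tokenizer")

tokenizer = JurassicTokenizer(model_path=model_dir)
print(tokenizer.encode("Hello, Jurassic!"))
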
37 changes: 11 additions & 26 deletions ai21_tokenizer/tokenizer_factory.py
@@ -1,22 +1,18 @@
import os
from pathlib import Path
from typing import Dict, Any

from ai21_tokenizer.base_tokenizer import BaseTokenizer
from ai21_tokenizer.jamba_instruct_tokenizer import JambaInstructTokenizer
from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer
from ai21_tokenizer.utils import load_json

_LOCAL_RESOURCES_PATH = Path(__file__).parent / "resources"
_MODEL_EXTENSION = ".model"
_MODEL_CONFIG_FILENAME = "config.json"
_ENV_CACHE_DIR_KEY = "AI21_TOKENIZER_CACHE_DIR"
JAMABA_TOKENIZER_HF_PATH = "ai21labs/Jamba-v0.1"


class PreTrainedTokenizers:
J2_TOKENIZER = "j2-tokenizer"


_PRETRAINED_MODEL_NAMES = [
PreTrainedTokenizers.J2_TOKENIZER,
]
JAMBA_INSTRUCT_TOKENIZER = "jamba-instruct-tokenizer"


class TokenizerFactory:
@@ -25,23 +21,12 @@ class TokenizerFactory:
Currently supports only J2-Tokenizer
"""

_tokenizer_name = PreTrainedTokenizers.J2_TOKENIZER

@classmethod
def get_tokenizer(cls) -> BaseTokenizer:
config = cls._get_config(cls._tokenizer_name)
model_path = cls._model_path(cls._tokenizer_name)
return JurassicTokenizer(model_path=model_path, config=config)
def get_tokenizer(cls, tokenizer_name: str = PreTrainedTokenizers.J2_TOKENIZER) -> BaseTokenizer:
if tokenizer_name == PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER:
return JambaInstructTokenizer(model_path=JAMABA_TOKENIZER_HF_PATH, cache_dir=os.getenv(_ENV_CACHE_DIR_KEY))

@classmethod
def _tokenizer_dir(cls, tokenizer_name: str) -> Path:
return _LOCAL_RESOURCES_PATH / tokenizer_name
if tokenizer_name == PreTrainedTokenizers.J2_TOKENIZER:
return JurassicTokenizer(_LOCAL_RESOURCES_PATH / PreTrainedTokenizers.J2_TOKENIZER)

@classmethod
def _model_path(cls, tokenizer_name: str) -> Path:
return cls._tokenizer_dir(tokenizer_name) / f"{tokenizer_name}{_MODEL_EXTENSION}"

@classmethod
def _get_config(cls, tokenizer_name: str) -> Dict[str, Any]:
config_path = cls._tokenizer_dir(tokenizer_name) / _MODEL_CONFIG_FILENAME
return load_json(config_path)
raise ValueError(f"Tokenizer {tokenizer_name} is not supported")
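
A sketch of the reworked factory entry point (not part of the diff). The default remains the bundled J2 tokenizer; requesting the Jamba instruct tokenizer pulls from the Hub on first use, and the AI21_TOKENIZER_CACHE_DIR environment variable, read via os.getenv at call time, optionally overrides the cache location. The cache path below is illustrative; any unknown name raises ValueError.

import os
from ai21_tokenizer import Tokenizer, PreTrainedTokenizers

# Optional: choose where the Jamba tokenizer.json gets cached (illustrative path).
os.environ["AI21_TOKENIZER_CACHE_DIR"] = "/tmp/ai21_tokenizer_cache"

j2 = Tokenizer.get_tokenizer()  # defaults to PreTrainedTokenizers.J2_TOKENIZER
jamba = Tokenizer.get_tokenizer(PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER)

print(type(j2).__name__, type(jamba).__name__)
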