Move preprocessing to base classes #1807

Merged · 5 commits · Sep 5, 2024
97 changes: 5 additions & 92 deletions keras_nlp/src/models/albert/albert_masked_lm_preprocessor.py
@@ -12,24 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import keras
-from absl import logging
-
 from keras_nlp.src.api_export import keras_nlp_export
-from keras_nlp.src.layers.preprocessing.masked_lm_mask_generator import (
-    MaskedLMMaskGenerator,
-)
-from keras_nlp.src.models.albert.albert_text_classifier_preprocessor import (
-    AlbertTextClassifierPreprocessor,
-)
+from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone
+from keras_nlp.src.models.albert.albert_tokenizer import AlbertTokenizer
 from keras_nlp.src.models.masked_lm_preprocessor import MaskedLMPreprocessor
-from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
 
 
 @keras_nlp_export("keras_nlp.models.AlbertMaskedLMPreprocessor")
-class AlbertMaskedLMPreprocessor(
-    AlbertTextClassifierPreprocessor, MaskedLMPreprocessor
-):
+class AlbertMaskedLMPreprocessor(MaskedLMPreprocessor):
     """ALBERT preprocessing for the masked language modeling task.
 
     This preprocessing layer will prepare inputs for a masked language modeling
@@ -120,82 +110,5 @@ class AlbertMaskedLMPreprocessor(
     ```
     """
 
-    def __init__(
-        self,
-        tokenizer,
-        sequence_length=512,
-        truncate="round_robin",
-        mask_selection_rate=0.15,
-        mask_selection_length=96,
-        mask_token_rate=0.8,
-        random_token_rate=0.1,
-        **kwargs,
-    ):
-        super().__init__(
-            tokenizer,
-            sequence_length=sequence_length,
-            truncate=truncate,
-            **kwargs,
-        )
-        self.mask_selection_rate = mask_selection_rate
-        self.mask_selection_length = mask_selection_length
-        self.mask_token_rate = mask_token_rate
-        self.random_token_rate = random_token_rate
-        self.masker = None
-
-    def build(self, input_shape):
-        super().build(input_shape)
-        # Defer masker creation to `build()` so that we can be sure tokenizer
-        # assets have loaded when restoring a saved model.
-        self.masker = MaskedLMMaskGenerator(
-            mask_selection_rate=self.mask_selection_rate,
-            mask_selection_length=self.mask_selection_length,
-            mask_token_rate=self.mask_token_rate,
-            random_token_rate=self.random_token_rate,
-            vocabulary_size=self.tokenizer.vocabulary_size(),
-            mask_token_id=self.tokenizer.mask_token_id,
-            unselectable_token_ids=[
-                self.tokenizer.cls_token_id,
-                self.tokenizer.sep_token_id,
-                self.tokenizer.pad_token_id,
-            ],
-        )
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "mask_selection_rate": self.mask_selection_rate,
-                "mask_selection_length": self.mask_selection_length,
-                "mask_token_rate": self.mask_token_rate,
-                "random_token_rate": self.random_token_rate,
-            }
-        )
-        return config
-
-    @tf_preprocessing_function
-    def call(self, x, y=None, sample_weight=None):
-        if y is not None or sample_weight is not None:
-            logging.warning(
-                f"{self.__class__.__name__} generates `y` and `sample_weight` "
-                "based on your input data, but your data already contains `y` "
-                "or `sample_weight`. Your `y` and `sample_weight` will be "
-                "ignored."
-            )
-
-        x = super().call(x)
-        token_ids, segment_ids, padding_mask = (
-            x["token_ids"],
-            x["segment_ids"],
-            x["padding_mask"],
-        )
-        masker_outputs = self.masker(token_ids)
-        x = {
-            "token_ids": masker_outputs["token_ids"],
-            "segment_ids": segment_ids,
-            "padding_mask": padding_mask,
-            "mask_positions": masker_outputs["mask_positions"],
-        }
-        y = masker_outputs["mask_ids"]
-        sample_weight = masker_outputs["mask_weights"]
-        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+    backbone_cls = AlbertBackbone
+    tokenizer_cls = AlbertTokenizer
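
With the subclass reduced to `backbone_cls`/`tokenizer_cls` bindings, masking and packing behavior now comes from the shared `MaskedLMPreprocessor` base class. A minimal usage sketch, assuming the standard KerasNLP preset API and the `albert_base_en_uncased` preset name:

```python
import keras_nlp

# Load the preprocessor from a preset; sequence_length is optional.
preprocessor = keras_nlp.models.AlbertMaskedLMPreprocessor.from_preset(
    "albert_base_en_uncased",
    sequence_length=128,
)

# The layer derives `y` (the ids of the masked tokens) and `sample_weight`
# (weights over the mask positions) from the input text itself.
x, y, sample_weight = preprocessor("The quick brown fox jumped.")
```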
64 changes: 0 additions & 64 deletions keras_nlp/src/models/albert/albert_text_classifier_preprocessor.py
@@ -12,18 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import keras
-
 from keras_nlp.src.api_export import keras_nlp_export
-from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
-    MultiSegmentPacker,
-)
 from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone
 from keras_nlp.src.models.albert.albert_tokenizer import AlbertTokenizer
 from keras_nlp.src.models.text_classifier_preprocessor import (
     TextClassifierPreprocessor,
 )
-from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
 
 
 @keras_nlp_export(
@@ -154,61 +148,3 @@ class AlbertTextClassifierPreprocessor(TextClassifierPreprocessor):
 
     backbone_cls = AlbertBackbone
     tokenizer_cls = AlbertTokenizer
-
-    def __init__(
-        self,
-        tokenizer,
-        sequence_length=512,
-        truncate="round_robin",
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        self.tokenizer = tokenizer
-        self.packer = None
-        self.truncate = truncate
-        self.sequence_length = sequence_length
-
-    def build(self, input_shape):
-        # Defer packer creation to `build()` so that we can be sure tokenizer
-        # assets have loaded when restoring a saved model.
-        self.packer = MultiSegmentPacker(
-            start_value=self.tokenizer.cls_token_id,
-            end_value=self.tokenizer.sep_token_id,
-            pad_value=self.tokenizer.pad_token_id,
-            truncate=self.truncate,
-            sequence_length=self.sequence_length,
-        )
-        self.built = True
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "sequence_length": self.sequence_length,
-                "truncate": self.truncate,
-            }
-        )
-        return config
-
-    @tf_preprocessing_function
-    def call(self, x, y=None, sample_weight=None):
-        x = x if isinstance(x, tuple) else (x,)
-        x = tuple(self.tokenizer(segment) for segment in x)
-        token_ids, segment_ids = self.packer(x)
-        x = {
-            "token_ids": token_ids,
-            "segment_ids": segment_ids,
-            "padding_mask": token_ids != self.tokenizer.pad_token_id,
-        }
-        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
-
-    @property
-    def sequence_length(self):
-        """The padded length of model input sequences."""
-        return self._sequence_length
-
-    @sequence_length.setter
-    def sequence_length(self, value):
-        self._sequence_length = value
-        if self.packer is not None:
-            self.packer.sequence_length = value
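
Everything deleted here is model-agnostic, which is what lets it move up the hierarchy. A rough sketch of the pattern the shared base class presumably follows (illustrative only, not the merged `TextClassifierPreprocessor` code):

```python
import keras

class TextClassifierPreprocessorSketch(keras.layers.Layer):
    """Sketch: generic tokenize-and-pack logic, shared across models."""

    def __init__(self, tokenizer, sequence_length=512, truncate="round_robin", **kwargs):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        self.packer = None  # Built lazily, once tokenizer assets have loaded.
        self.truncate = truncate
        self.sequence_length = sequence_length

    def call(self, x):
        # Accept a single segment or a tuple of segments, tokenize each,
        # then pack them into one sequence with special tokens.
        x = x if isinstance(x, tuple) else (x,)
        x = tuple(self.tokenizer(segment) for segment in x)
        token_ids, segment_ids = self.packer(x)
        return {
            "token_ids": token_ids,
            "segment_ids": segment_ids,
            "padding_mask": token_ids != self.tokenizer.pad_token_id,
        }

# A concrete subclass then shrinks to class-level bindings, e.g.
# `backbone_cls = AlbertBackbone` and `tokenizer_cls = AlbertTokenizer`.
```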
39 changes: 8 additions & 31 deletions keras_nlp/src/models/albert/albert_tokenizer.py
@@ -89,35 +89,12 @@ class AlbertTokenizer(SentencePieceTokenizer):
     backbone_cls = AlbertBackbone
 
     def __init__(self, proto, **kwargs):
-        self.cls_token = "[CLS]"
-        self.sep_token = "[SEP]"
-        self.pad_token = "<pad>"
-        self.mask_token = "[MASK]"
-
+        self._add_special_token("[CLS]", "cls_token")
+        self._add_special_token("[SEP]", "sep_token")
+        self._add_special_token("<pad>", "pad_token")
+        self._add_special_token("[MASK]", "mask_token")
+        # Also add `tokenizer.start_token` and `tokenizer.end_token` for
+        # compatibility with other tokenizers.
+        self._add_special_token("[CLS]", "start_token")
+        self._add_special_token("[SEP]", "end_token")
         super().__init__(proto=proto, **kwargs)
-
-    def set_proto(self, proto):
-        super().set_proto(proto)
-        if proto is not None:
-            for token in [
-                self.cls_token,
-                self.sep_token,
-                self.pad_token,
-                self.mask_token,
-            ]:
-                if token not in self.get_vocabulary():
-                    raise ValueError(
-                        f"Cannot find token `'{token}'` in the provided "
-                        f"`vocabulary`. Please provide `'{token}'` in your "
-                        "`vocabulary` or use a pretrained `vocabulary` name."
-                    )
-
-            self.cls_token_id = self.token_to_id(self.cls_token)
-            self.sep_token_id = self.token_to_id(self.sep_token)
-            self.pad_token_id = self.token_to_id(self.pad_token)
-            self.mask_token_id = self.token_to_id(self.mask_token)
-        else:
-            self.cls_token_id = None
-            self.sep_token_id = None
-            self.pad_token_id = None
-            self.mask_token_id = None
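
The deleted `set_proto` override validated each special token against the vocabulary and cached its id by hand; `_add_special_token` presumably centralizes that bookkeeping on the tokenizer base classes. A rough sketch of the idea (an assumption inferred from this diff, not the actual helper):

```python
def _add_special_token(self, token, name):
    # Sketch only: remember the token under a named attribute (e.g.
    # `cls_token`) so the base class can later validate every registered
    # token against the vocabulary and expose a matching `<name>_id`
    # attribute (e.g. `cls_token_id`) once the proto is set.
    setattr(self, name, token)
    self._special_tokens = getattr(self, "_special_tokens", {})
    self._special_tokens[name] = token
```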
4 changes: 2 additions & 2 deletions keras_nlp/src/models/backbone.py
@@ -160,7 +160,7 @@ def from_preset(
         to save and load a pre-trained model. The `preset` can be passed as
         one of:
 
-        1. a built in preset identifier like `'bert_base_en'`
+        1. a built-in preset identifier like `'bert_base_en'`
         2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
         3. a Hugging Face handle like `'hf://user/bert_base_en'`
         4. a path to a local preset directory like `'./bert_base_en'`
@@ -175,7 +175,7 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
         all built-in presets available on the class.
 
         Args:
-            preset: string. A built in preset identifier, a Kaggle Models
+            preset: string. A built-in preset identifier, a Kaggle Models
                 handle, a Hugging Face handle, or a path to a local directory.
             load_weights: bool. If `True`, the weights will be loaded into the
                 model architecture. If `False`, the weights will be randomly
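
For reference, the four documented forms are interchangeable; a minimal sketch using the docstring's own example identifiers:

```python
import keras_nlp

# 1. Built-in preset identifier.
backbone = keras_nlp.models.Backbone.from_preset("bert_base_en")
# 2. Kaggle Models handle (placeholder `user` from the docstring).
# backbone = keras_nlp.models.Backbone.from_preset("kaggle://user/bert/keras/bert_base_en")
# 3. Hugging Face handle.
# backbone = keras_nlp.models.Backbone.from_preset("hf://user/bert_base_en")
# 4. Local preset directory.
# backbone = keras_nlp.models.Backbone.from_preset("./bert_base_en")
```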