Class detection works for huggingface checkpoints (#1800)
* Class detection works for huggingface checkpoints

This was a tricky one to fix that involved some large
refactoring to our preset loading routines.

Originally the intent was that `from_preset()` read as an easily
followable sequence of lower-level Keras calls. With the arrival
of transformers conversions, and soon timm conversions, I think
that goal is no longer realistic.

Instead I added a loader interface, with default implementations
of `load_task` and `load_preprocessor`. Every format we support
converting from directly has to support, at a minimum:
- Detecting the backbone class.
- Loading the backbone class.

One consequence of this work is that every class with a `from_preset`
constructor needs to reference the `backbone_cls` it matches. I
think this will be a more stable way to handle our "auto class"-like
functionality as we venture further toward multi-modal models.

* Address comments
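
To make this concrete, here is a minimal sketch of what such a loader interface could look like. All names and signatures below are illustrative assumptions rather than the exact API added in this commit; the real implementation lives in keras_nlp/src/utils/preset_utils.py.

class PresetLoader:
    """Base class for loading one preset format (Keras, transformers, ...)."""

    def __init__(self, preset):
        self.preset = preset

    def check_backbone_class(self):
        # Required: detect which backbone class this checkpoint maps to.
        raise NotImplementedError

    def load_backbone(self, cls, load_weights, **kwargs):
        # Required: build the backbone and optionally load its weights.
        raise NotImplementedError

    def load_tokenizer(self, cls, **kwargs):
        # Required for formats that ship a vocabulary.
        raise NotImplementedError

    def load_task(self, cls, load_weights, **kwargs):
        # Default implementation: wrap a freshly loaded backbone in the task.
        backbone = self.load_backbone(cls.backbone_cls, load_weights)
        return cls(backbone=backbone, **kwargs)

    def load_preprocessor(self, cls, **kwargs):
        # Default implementation: build the preprocessor around its tokenizer.
        tokenizer = self.load_tokenizer(cls.tokenizer_cls)
        return cls(tokenizer=tokenizer, **kwargs)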
mattdangerw authored Aug 28, 2024
1 parent fbc1335 commit 0c04abe
Showing 75 changed files with 623 additions and 585 deletions.
2 changes: 2 additions & 0 deletions keras_nlp/src/models/albert/albert_preprocessor.py
@@ -18,6 +18,7 @@
 from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
     MultiSegmentPacker,
 )
+from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone
 from keras_nlp.src.models.albert.albert_tokenizer import AlbertTokenizer
 from keras_nlp.src.models.preprocessor import Preprocessor
 from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -144,6 +145,7 @@ class AlbertPreprocessor(Preprocessor):
     ```
     """

+    backbone_cls = AlbertBackbone
     tokenizer_cls = AlbertTokenizer

     def __init__(
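
The backbone_cls attribute added here (and to every preprocessor and tokenizer below) is what powers the "auto class" detection: once a loader knows the backbone class, the matching tokenizer or preprocessor can be found by scanning registered subclasses. A rough sketch of that lookup, using the existing list_subclasses utility but a hypothetical helper name:

from keras_nlp.src.utils.preset_utils import list_subclasses


def find_class_for_backbone(base_cls, backbone_cls):
    # Hypothetical helper: return the subclass of `base_cls` (e.g. a
    # tokenizer or preprocessor base class) whose `backbone_cls` matches
    # the backbone class detected from the checkpoint.
    for subclass in list_subclasses(base_cls):
        if getattr(subclass, "backbone_cls", None) is backbone_cls:
            return subclass
    raise ValueError(
        f"No {base_cls.__name__} subclass found for `{backbone_cls.__name__}`."
    )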
3 changes: 3 additions & 0 deletions keras_nlp/src/models/albert/albert_tokenizer.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone
 from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
     SentencePieceTokenizer,
 )
@@ -84,6 +85,8 @@ class AlbertTokenizer(SentencePieceTokenizer):
     ```
     """

+    backbone_cls = AlbertBackbone
+
     def __init__(self, proto, **kwargs):
         self.cls_token = "[CLS]"
         self.sep_token = "[SEP]"
29 changes: 7 additions & 22 deletions keras_nlp/src/models/backbone.py
@@ -20,17 +20,12 @@
 from keras_nlp.src.utils.keras_utils import assert_quantization_support
 from keras_nlp.src.utils.preset_utils import CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
-from keras_nlp.src.utils.preset_utils import check_config_class
-from keras_nlp.src.utils.preset_utils import check_format
-from keras_nlp.src.utils.preset_utils import get_file
-from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
+from keras_nlp.src.utils.preset_utils import get_preset_loader
 from keras_nlp.src.utils.preset_utils import list_presets
 from keras_nlp.src.utils.preset_utils import list_subclasses
-from keras_nlp.src.utils.preset_utils import load_serialized_object
 from keras_nlp.src.utils.preset_utils import save_metadata
-from keras_nlp.src.utils.preset_utils import save_serialized_object
 from keras_nlp.src.utils.python_utils import classproperty
-from keras_nlp.src.utils.transformers.convert import load_transformers_backbone


 @keras_nlp_export("keras_nlp.models.Backbone")
@@ -200,25 +195,15 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
         )
         ```
         """
-        format = check_format(preset)
-
-        if format == "transformers":
-            return load_transformers_backbone(cls, preset, load_weights)
-
-        preset_cls = check_config_class(preset)
-        if not issubclass(preset_cls, cls):
+        loader = get_preset_loader(preset)
+        backbone_cls = loader.check_backbone_class()
+        if not issubclass(backbone_cls, cls):
             raise ValueError(
-                f"Preset has type `{preset_cls.__name__}` which is not a "
-                f"subclass of calling class `{cls.__name__}`. Call "
-                f"`from_preset` directly on `{preset_cls.__name__}` instead."
+                f"Saved preset has type `{backbone_cls.__name__}` which is not "
+                f"a subclass of calling class `{cls.__name__}`. Call "
+                f"`from_preset` directly on `{backbone_cls.__name__}` instead."
             )
-
-        backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs)
-        if load_weights:
-            jax_memory_cleanup(backbone)
-            backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
-
-        return backbone
+        return loader.load_backbone(backbone_cls, load_weights, **kwargs)

     def save_to_preset(self, preset_dir):
         """Save backbone to a preset directory.
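
With the loader in place, the same call now works for Keras presets and directly converted Hugging Face checkpoints alike; class detection picks the right backbone subclass in either case. A usage sketch (the checkpoint names are examples):

import keras_nlp

# Keras preset: the saved config.json names the backbone class.
backbone = keras_nlp.models.Backbone.from_preset("bert_tiny_en_uncased")

# Hugging Face checkpoint: the transformers config is mapped to the
# equivalent KerasNLP backbone class before weights are converted.
backbone = keras_nlp.models.Backbone.from_preset(
    "hf://google-bert/bert-base-uncased"
)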
8 changes: 4 additions & 4 deletions keras_nlp/src/models/backbone_test.py
@@ -24,7 +24,7 @@
 from keras_nlp.src.utils.preset_utils import METADATA_FILE
 from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
 from keras_nlp.src.utils.preset_utils import check_config_class
-from keras_nlp.src.utils.preset_utils import load_config
+from keras_nlp.src.utils.preset_utils import load_json


 class TestBackbone(TestCase):
@@ -68,7 +68,7 @@ def test_from_preset_errors(self):
             GPT2Backbone.from_preset("bert_tiny_en_uncased", load_weights=False)
         with self.assertRaises(ValueError):
             # No loading on a non-keras model.
-            Backbone.from_preset("hf://google-bert/bert-base-uncased")
+            Backbone.from_preset("hf://spacy/en_core_web_sm")

     @pytest.mark.large
     def test_save_to_preset(self):
@@ -84,12 +84,12 @@ def test_save_to_preset(self):
         self.assertTrue(os.path.exists(os.path.join(save_dir, METADATA_FILE)))

         # Check the backbone config (`config.json`).
-        backbone_config = load_config(save_dir, CONFIG_FILE)
+        backbone_config = load_json(save_dir, CONFIG_FILE)
         self.assertTrue("build_config" not in backbone_config)
         self.assertTrue("compile_config" not in backbone_config)

         # Try config class.
-        self.assertEqual(BertBackbone, check_config_class(save_dir))
+        self.assertEqual(BertBackbone, check_config_class(backbone_config))

         # Try loading the model from preset directory.
         restored_backbone = Backbone.from_preset(save_dir)
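
The updated assertions also imply a signature change: check_config_class now takes the already-parsed config dict (from load_json) rather than a preset directory. A sketch of the assumed new usage, with save_dir being a preset directory as in the test above:

from keras_nlp.src.utils.preset_utils import CONFIG_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import load_json

config = load_json(save_dir, CONFIG_FILE)  # parse the preset's config.json
backbone_cls = check_config_class(config)  # map the serialized class name back to a class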
2 changes: 2 additions & 0 deletions keras_nlp/src/models/bart/bart_preprocessor.py
@@ -17,6 +17,7 @@

 from keras_nlp.src.api_export import keras_nlp_export
 from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_nlp.src.models.bart.bart_backbone import BartBackbone
 from keras_nlp.src.models.bart.bart_tokenizer import BartTokenizer
 from keras_nlp.src.models.preprocessor import Preprocessor
 from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -127,6 +128,7 @@ class BartPreprocessor(Preprocessor):
     ```
     """

+    backbone_cls = BartBackbone
     tokenizer_cls = BartTokenizer

     def __init__(
3 changes: 3 additions & 0 deletions keras_nlp/src/models/bart/bart_tokenizer.py
@@ -14,6 +14,7 @@


 from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.bart.bart_backbone import BartBackbone
 from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer


@@ -73,6 +74,8 @@ class BartTokenizer(BytePairTokenizer):
     ```
     """

+    backbone_cls = BartBackbone
+
     def __init__(
         self,
         vocabulary=None,
2 changes: 2 additions & 0 deletions keras_nlp/src/models/bert/bert_preprocessor.py
@@ -18,6 +18,7 @@
 from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
     MultiSegmentPacker,
 )
+from keras_nlp.src.models.bert.bert_backbone import BertBackbone
 from keras_nlp.src.models.bert.bert_tokenizer import BertTokenizer
 from keras_nlp.src.models.preprocessor import Preprocessor
 from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -122,6 +123,7 @@ class BertPreprocessor(Preprocessor):
     ```
     """

+    backbone_cls = BertBackbone
     tokenizer_cls = BertTokenizer

     def __init__(
3 changes: 3 additions & 0 deletions keras_nlp/src/models/bert/bert_tokenizer.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.bert.bert_backbone import BertBackbone
 from keras_nlp.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer


@@ -68,6 +69,8 @@ class BertTokenizer(WordPieceTokenizer):
     ```
     """

+    backbone_cls = BertBackbone
+
     def __init__(
         self,
         vocabulary=None,
2 changes: 2 additions & 0 deletions keras_nlp/src/models/bloom/bloom_preprocessor.py
@@ -17,6 +17,7 @@

 from keras_nlp.src.api_export import keras_nlp_export
 from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_nlp.src.models.bloom.bloom_backbone import BloomBackbone
 from keras_nlp.src.models.bloom.bloom_tokenizer import BloomTokenizer
 from keras_nlp.src.models.preprocessor import Preprocessor
 from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -103,6 +104,7 @@ class BloomPreprocessor(Preprocessor):
     ```
     """

+    backbone_cls = BloomBackbone
     tokenizer_cls = BloomTokenizer

     def __init__(
3 changes: 3 additions & 0 deletions keras_nlp/src/models/bloom/bloom_tokenizer.py
@@ -14,6 +14,7 @@


 from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.bloom.bloom_backbone import BloomBackbone
 from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer


@@ -65,6 +66,8 @@ class BloomTokenizer(BytePairTokenizer):
     ```
     """

+    backbone_cls = BloomBackbone
+
     def __init__(
         self,
         vocabulary=None,
4 changes: 4 additions & 0 deletions keras_nlp/src/models/deberta_v3/deberta_v3_preprocessor.py
@@ -19,6 +19,9 @@
 from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
     MultiSegmentPacker,
 )
+from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import (
+    DebertaV3Backbone,
+)
 from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import (
     DebertaV3Tokenizer,
 )
@@ -145,6 +148,7 @@ class DebertaV3Preprocessor(Preprocessor):
     ```
     """

+    backbone_cls = DebertaV3Backbone
     tokenizer_cls = DebertaV3Tokenizer

     def __init__(
5 changes: 5 additions & 0 deletions keras_nlp/src/models/deberta_v3/deberta_v3_tokenizer.py
@@ -14,6 +14,9 @@


 from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import (
+    DebertaV3Backbone,
+)
 from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
     SentencePieceTokenizer,
 )
@@ -94,6 +97,8 @@ class DebertaV3Tokenizer(SentencePieceTokenizer):
     ```
     """

+    backbone_cls = DebertaV3Backbone
+
     def __init__(self, proto, **kwargs):
         self.cls_token = "[CLS]"
         self.sep_token = "[SEP]"
4 changes: 4 additions & 0 deletions keras_nlp/src/models/distil_bert/distil_bert_preprocessor.py
@@ -19,6 +19,9 @@
 from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
     MultiSegmentPacker,
 )
+from keras_nlp.src.models.distil_bert.distil_bert_backbone import (
+    DistilBertBackbone,
+)
 from keras_nlp.src.models.distil_bert.distil_bert_tokenizer import (
     DistilBertTokenizer,
 )
@@ -114,6 +117,7 @@ class DistilBertPreprocessor(Preprocessor):
     ```
     """

+    backbone_cls = DistilBertBackbone
     tokenizer_cls = DistilBertTokenizer

     def __init__(
5 changes: 5 additions & 0 deletions keras_nlp/src/models/distil_bert/distil_bert_tokenizer.py
@@ -14,6 +14,9 @@


 from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.distil_bert.distil_bert_backbone import (
+    DistilBertBackbone,
+)
 from keras_nlp.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer


@@ -70,6 +73,8 @@ class DistilBertTokenizer(WordPieceTokenizer):
     ```
     """

+    backbone_cls = DistilBertBackbone
+
     def __init__(
         self,
         vocabulary,
2 changes: 2 additions & 0 deletions keras_nlp/src/models/electra/electra_preprocessor.py
@@ -18,6 +18,7 @@
 from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
     MultiSegmentPacker,
 )
+from keras_nlp.src.models.electra.electra_backbone import ElectraBackbone
 from keras_nlp.src.models.electra.electra_tokenizer import ElectraTokenizer
 from keras_nlp.src.models.preprocessor import Preprocessor
 from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -111,6 +112,7 @@ class ElectraPreprocessor(Preprocessor):
     ```
     """

+    backbone_cls = ElectraBackbone
     tokenizer_cls = ElectraTokenizer

     def __init__(
3 changes: 3 additions & 0 deletions keras_nlp/src/models/electra/electra_tokenizer.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.electra.electra_backbone import ElectraBackbone
 from keras_nlp.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer


@@ -60,6 +61,8 @@ class ElectraTokenizer(WordPieceTokenizer):
     ```
     """

+    backbone_cls = ElectraBackbone
+
     def __init__(
         self,
         vocabulary,
2 changes: 2 additions & 0 deletions keras_nlp/src/models/f_net/f_net_preprocessor.py
@@ -19,6 +19,7 @@
 from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
     MultiSegmentPacker,
 )
+from keras_nlp.src.models.f_net.f_net_backbone import FNetBackbone
 from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer
 from keras_nlp.src.models.preprocessor import Preprocessor
 from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -116,6 +117,7 @@ class FNetPreprocessor(Preprocessor):
     ```
     """

+    backbone_cls = FNetBackbone
     tokenizer_cls = FNetTokenizer

     def __init__(
3 changes: 3 additions & 0 deletions keras_nlp/src/models/f_net/f_net_tokenizer.py
@@ -14,6 +14,7 @@


 from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.f_net.f_net_backbone import FNetBackbone
 from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
     SentencePieceTokenizer,
 )
@@ -61,6 +62,8 @@ class FNetTokenizer(SentencePieceTokenizer):
     ```
     """

+    backbone_cls = FNetBackbone
+
     def __init__(self, proto, **kwargs):
         self.cls_token = "[CLS]"
         self.sep_token = "[SEP]"
2 changes: 2 additions & 0 deletions keras_nlp/src/models/falcon/falcon_preprocessor.py
@@ -17,6 +17,7 @@

 from keras_nlp.src.api_export import keras_nlp_export
 from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
 from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer
 from keras_nlp.src.models.preprocessor import Preprocessor
 from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
@@ -105,6 +106,7 @@ class FalconPreprocessor(Preprocessor):
     ```
     """

+    backbone_cls = FalconBackbone
     tokenizer_cls = FalconTokenizer

     def __init__(
3 changes: 3 additions & 0 deletions keras_nlp/src/models/falcon/falcon_tokenizer.py
@@ -14,6 +14,7 @@


 from keras_nlp.src.api_export import keras_nlp_export
+from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
 from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer


@@ -65,6 +66,8 @@ class FalconTokenizer(BytePairTokenizer):
     ```
     """

+    backbone_cls = FalconBackbone
+
     def __init__(
         self,
         vocabulary=None,