Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Class detection works for huggingface checkpoints #1800

Merged
Merged 2 commits on Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions keras_nlp/src/models/albert/albert_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
MultiSegmentPacker,
)
from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.src.models.albert.albert_tokenizer import AlbertTokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
Expand Down Expand Up @@ -144,6 +145,7 @@ class AlbertPreprocessor(Preprocessor):
```
"""

backbone_cls = AlbertBackbone
tokenizer_cls = AlbertTokenizer

def __init__(
Expand Down
3 changes: 3 additions & 0 deletions keras_nlp/src/models/albert/albert_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
SentencePieceTokenizer,
)
Expand Down Expand Up @@ -84,6 +85,8 @@ class AlbertTokenizer(SentencePieceTokenizer):
```
"""

backbone_cls = AlbertBackbone

def __init__(self, proto, **kwargs):
self.cls_token = "[CLS]"
self.sep_token = "[SEP]"
Expand Down
29 changes: 7 additions & 22 deletions keras_nlp/src/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,12 @@
from keras_nlp.src.utils.keras_utils import assert_quantization_support
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import check_format
from keras_nlp.src.utils.preset_utils import get_file
from keras_nlp.src.utils.preset_utils import jax_memory_cleanup
from keras_nlp.src.utils.preset_utils import get_preset_loader
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import load_serialized_object
from keras_nlp.src.utils.preset_utils import save_metadata
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.python_utils import classproperty
from keras_nlp.src.utils.transformers.convert import load_transformers_backbone


@keras_nlp_export("keras_nlp.models.Backbone")
Expand Down Expand Up @@ -200,25 +195,15 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
)
```
"""
format = check_format(preset)

if format == "transformers":
return load_transformers_backbone(cls, preset, load_weights)

preset_cls = check_config_class(preset)
if not issubclass(preset_cls, cls):
loader = get_preset_loader(preset)
backbone_cls = loader.check_backbone_class()
if not issubclass(backbone_cls, cls):
raise ValueError(
f"Preset has type `{preset_cls.__name__}` which is not a "
f"Saved preset has type `{backbone_cls.__name__}` which is not "
f"a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{preset_cls.__name__}` instead."
f"`from_preset` directly on `{backbone_cls.__name__}` instead."
)

backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs)
if load_weights:
jax_memory_cleanup(backbone)
backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))

return backbone
return loader.load_backbone(backbone_cls, load_weights, **kwargs)

def save_to_preset(self, preset_dir):
"""Save backbone to a preset directory.
Expand Down
8 changes: 4 additions & 4 deletions keras_nlp/src/models/backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from keras_nlp.src.utils.preset_utils import METADATA_FILE
from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import load_config
from keras_nlp.src.utils.preset_utils import load_json


class TestBackbone(TestCase):
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_from_preset_errors(self):
GPT2Backbone.from_preset("bert_tiny_en_uncased", load_weights=False)
with self.assertRaises(ValueError):
# No loading on a non-keras model.
Backbone.from_preset("hf://google-bert/bert-base-uncased")
Backbone.from_preset("hf://spacy/en_core_web_sm")

@pytest.mark.large
def test_save_to_preset(self):
Expand All @@ -84,12 +84,12 @@ def test_save_to_preset(self):
self.assertTrue(os.path.exists(os.path.join(save_dir, METADATA_FILE)))

# Check the backbone config (`config.json`).
backbone_config = load_config(save_dir, CONFIG_FILE)
backbone_config = load_json(save_dir, CONFIG_FILE)
self.assertTrue("build_config" not in backbone_config)
self.assertTrue("compile_config" not in backbone_config)

# Try config class.
self.assertEqual(BertBackbone, check_config_class(save_dir))
self.assertEqual(BertBackbone, check_config_class(backbone_config))

# Try loading the model from preset directory.
restored_backbone = Backbone.from_preset(save_dir)
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/src/models/bart/bart_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
from keras_nlp.src.models.bart.bart_backbone import BartBackbone
from keras_nlp.src.models.bart.bart_tokenizer import BartTokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
Expand Down Expand Up @@ -127,6 +128,7 @@ class BartPreprocessor(Preprocessor):
```
"""

backbone_cls = BartBackbone
tokenizer_cls = BartTokenizer

def __init__(
Expand Down
3 changes: 3 additions & 0 deletions keras_nlp/src/models/bart/bart_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.bart.bart_backbone import BartBackbone
from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer


Expand Down Expand Up @@ -73,6 +74,8 @@ class BartTokenizer(BytePairTokenizer):
```
"""

backbone_cls = BartBackbone

def __init__(
self,
vocabulary=None,
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/src/models/bert/bert_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
MultiSegmentPacker,
)
from keras_nlp.src.models.bert.bert_backbone import BertBackbone
from keras_nlp.src.models.bert.bert_tokenizer import BertTokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
Expand Down Expand Up @@ -122,6 +123,7 @@ class BertPreprocessor(Preprocessor):
```
"""

backbone_cls = BertBackbone
tokenizer_cls = BertTokenizer

def __init__(
Expand Down
3 changes: 3 additions & 0 deletions keras_nlp/src/models/bert/bert_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.bert.bert_backbone import BertBackbone
from keras_nlp.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer


Expand Down Expand Up @@ -68,6 +69,8 @@ class BertTokenizer(WordPieceTokenizer):
```
"""

backbone_cls = BertBackbone

def __init__(
self,
vocabulary=None,
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/src/models/bloom/bloom_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
from keras_nlp.src.models.bloom.bloom_backbone import BloomBackbone
from keras_nlp.src.models.bloom.bloom_tokenizer import BloomTokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
Expand Down Expand Up @@ -103,6 +104,7 @@ class BloomPreprocessor(Preprocessor):
```
"""

backbone_cls = BloomBackbone
tokenizer_cls = BloomTokenizer

def __init__(
Expand Down
3 changes: 3 additions & 0 deletions keras_nlp/src/models/bloom/bloom_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.bloom.bloom_backbone import BloomBackbone
from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer


Expand Down Expand Up @@ -65,6 +66,8 @@ class BloomTokenizer(BytePairTokenizer):
```
"""

backbone_cls = BloomBackbone

def __init__(
self,
vocabulary=None,
Expand Down
4 changes: 4 additions & 0 deletions keras_nlp/src/models/deberta_v3/deberta_v3_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
MultiSegmentPacker,
)
from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import (
DebertaV3Backbone,
)
from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import (
DebertaV3Tokenizer,
)
Expand Down Expand Up @@ -145,6 +148,7 @@ class DebertaV3Preprocessor(Preprocessor):
```
"""

backbone_cls = DebertaV3Backbone
tokenizer_cls = DebertaV3Tokenizer

def __init__(
Expand Down
5 changes: 5 additions & 0 deletions keras_nlp/src/models/deberta_v3/deberta_v3_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@


from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import (
DebertaV3Backbone,
)
from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
SentencePieceTokenizer,
)
Expand Down Expand Up @@ -94,6 +97,8 @@ class DebertaV3Tokenizer(SentencePieceTokenizer):
```
"""

backbone_cls = DebertaV3Backbone

def __init__(self, proto, **kwargs):
self.cls_token = "[CLS]"
self.sep_token = "[SEP]"
Expand Down
4 changes: 4 additions & 0 deletions keras_nlp/src/models/distil_bert/distil_bert_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
MultiSegmentPacker,
)
from keras_nlp.src.models.distil_bert.distil_bert_backbone import (
DistilBertBackbone,
)
from keras_nlp.src.models.distil_bert.distil_bert_tokenizer import (
DistilBertTokenizer,
)
Expand Down Expand Up @@ -114,6 +117,7 @@ class DistilBertPreprocessor(Preprocessor):
```
"""

backbone_cls = DistilBertBackbone
tokenizer_cls = DistilBertTokenizer

def __init__(
Expand Down
5 changes: 5 additions & 0 deletions keras_nlp/src/models/distil_bert/distil_bert_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@


from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.distil_bert.distil_bert_backbone import (
DistilBertBackbone,
)
from keras_nlp.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer


Expand Down Expand Up @@ -70,6 +73,8 @@ class DistilBertTokenizer(WordPieceTokenizer):
```
"""

backbone_cls = DistilBertBackbone

def __init__(
self,
vocabulary,
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/src/models/electra/electra_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
MultiSegmentPacker,
)
from keras_nlp.src.models.electra.electra_backbone import ElectraBackbone
from keras_nlp.src.models.electra.electra_tokenizer import ElectraTokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
Expand Down Expand Up @@ -111,6 +112,7 @@ class ElectraPreprocessor(Preprocessor):
```
"""

backbone_cls = ElectraBackbone
tokenizer_cls = ElectraTokenizer

def __init__(
Expand Down
3 changes: 3 additions & 0 deletions keras_nlp/src/models/electra/electra_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.electra.electra_backbone import ElectraBackbone
from keras_nlp.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer


Expand Down Expand Up @@ -60,6 +61,8 @@ class ElectraTokenizer(WordPieceTokenizer):
```
"""

backbone_cls = ElectraBackbone

def __init__(
self,
vocabulary,
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/src/models/f_net/f_net_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
MultiSegmentPacker,
)
from keras_nlp.src.models.f_net.f_net_backbone import FNetBackbone
from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
Expand Down Expand Up @@ -116,6 +117,7 @@ class FNetPreprocessor(Preprocessor):
```
"""

backbone_cls = FNetBackbone
tokenizer_cls = FNetTokenizer

def __init__(
Expand Down
3 changes: 3 additions & 0 deletions keras_nlp/src/models/f_net/f_net_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.f_net.f_net_backbone import FNetBackbone
from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
SentencePieceTokenizer,
)
Expand Down Expand Up @@ -61,6 +62,8 @@ class FNetTokenizer(SentencePieceTokenizer):
```
"""

backbone_cls = FNetBackbone

def __init__(self, proto, **kwargs):
self.cls_token = "[CLS]"
self.sep_token = "[SEP]"
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/src/models/falcon/falcon_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function
Expand Down Expand Up @@ -105,6 +106,7 @@ class FalconPreprocessor(Preprocessor):
```
"""

backbone_cls = FalconBackbone
tokenizer_cls = FalconTokenizer

def __init__(
Expand Down
3 changes: 3 additions & 0 deletions keras_nlp/src/models/falcon/falcon_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer


Expand Down Expand Up @@ -65,6 +66,8 @@ class FalconTokenizer(BytePairTokenizer):
```
"""

backbone_cls = FalconBackbone

def __init__(
self,
vocabulary=None,
Expand Down
Loading
Loading