diff --git a/keras_nlp/src/models/albert/albert_preprocessor.py b/keras_nlp/src/models/albert/albert_preprocessor.py index 1e502a0289..c9980701da 100644 --- a/keras_nlp/src/models/albert/albert_preprocessor.py +++ b/keras_nlp/src/models/albert/albert_preprocessor.py @@ -18,6 +18,7 @@ from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) +from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone from keras_nlp.src.models.albert.albert_tokenizer import AlbertTokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -144,6 +145,7 @@ class AlbertPreprocessor(Preprocessor): ``` """ + backbone_cls = AlbertBackbone tokenizer_cls = AlbertTokenizer def __init__( diff --git a/keras_nlp/src/models/albert/albert_tokenizer.py b/keras_nlp/src/models/albert/albert_tokenizer.py index 37223aa76c..c39768a5b9 100644 --- a/keras_nlp/src/models/albert/albert_tokenizer.py +++ b/keras_nlp/src/models/albert/albert_tokenizer.py @@ -13,6 +13,7 @@ # limitations under the License. 
from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone from keras_nlp.src.tokenizers.sentence_piece_tokenizer import ( SentencePieceTokenizer, ) @@ -84,6 +85,8 @@ class AlbertTokenizer(SentencePieceTokenizer): ``` """ + backbone_cls = AlbertBackbone + def __init__(self, proto, **kwargs): self.cls_token = "[CLS]" self.sep_token = "[SEP]" diff --git a/keras_nlp/src/models/backbone.py b/keras_nlp/src/models/backbone.py index d74f03ccd7..b252761e43 100644 --- a/keras_nlp/src/models/backbone.py +++ b/keras_nlp/src/models/backbone.py @@ -20,17 +20,12 @@ from keras_nlp.src.utils.keras_utils import assert_quantization_support from keras_nlp.src.utils.preset_utils import CONFIG_FILE from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE -from keras_nlp.src.utils.preset_utils import check_config_class -from keras_nlp.src.utils.preset_utils import check_format -from keras_nlp.src.utils.preset_utils import get_file -from keras_nlp.src.utils.preset_utils import jax_memory_cleanup +from keras_nlp.src.utils.preset_utils import get_preset_loader from keras_nlp.src.utils.preset_utils import list_presets from keras_nlp.src.utils.preset_utils import list_subclasses -from keras_nlp.src.utils.preset_utils import load_serialized_object from keras_nlp.src.utils.preset_utils import save_metadata from keras_nlp.src.utils.preset_utils import save_serialized_object from keras_nlp.src.utils.python_utils import classproperty -from keras_nlp.src.utils.transformers.convert import load_transformers_backbone @keras_nlp_export("keras_nlp.models.Backbone") @@ -200,25 +195,15 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from ) ``` """ - format = check_format(preset) - - if format == "transformers": - return load_transformers_backbone(cls, preset, load_weights) - - preset_cls = check_config_class(preset) - if not issubclass(preset_cls, cls): + loader = get_preset_loader(preset) + backbone_cls = 
loader.check_backbone_class() + if not issubclass(backbone_cls, cls): raise ValueError( - f"Preset has type `{preset_cls.__name__}` which is not a " + f"Saved preset has type `{backbone_cls.__name__}` which is not " f"a subclass of calling class `{cls.__name__}`. Call " - f"`from_preset` directly on `{preset_cls.__name__}` instead." + f"`from_preset` directly on `{backbone_cls.__name__}` instead." ) - - backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs) - if load_weights: - jax_memory_cleanup(backbone) - backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE)) - - return backbone + return loader.load_backbone(backbone_cls, load_weights, **kwargs) def save_to_preset(self, preset_dir): """Save backbone to a preset directory. diff --git a/keras_nlp/src/models/backbone_test.py b/keras_nlp/src/models/backbone_test.py index 966806592b..d707bec9fb 100644 --- a/keras_nlp/src/models/backbone_test.py +++ b/keras_nlp/src/models/backbone_test.py @@ -24,7 +24,7 @@ from keras_nlp.src.utils.preset_utils import METADATA_FILE from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE from keras_nlp.src.utils.preset_utils import check_config_class -from keras_nlp.src.utils.preset_utils import load_config +from keras_nlp.src.utils.preset_utils import load_json class TestBackbone(TestCase): @@ -68,7 +68,7 @@ def test_from_preset_errors(self): GPT2Backbone.from_preset("bert_tiny_en_uncased", load_weights=False) with self.assertRaises(ValueError): # No loading on a non-keras model. - Backbone.from_preset("hf://google-bert/bert-base-uncased") + Backbone.from_preset("hf://spacy/en_core_web_sm") @pytest.mark.large def test_save_to_preset(self): @@ -84,12 +84,12 @@ def test_save_to_preset(self): self.assertTrue(os.path.exists(os.path.join(save_dir, METADATA_FILE))) # Check the backbone config (`config.json`). 
- backbone_config = load_config(save_dir, CONFIG_FILE) + backbone_config = load_json(save_dir, CONFIG_FILE) self.assertTrue("build_config" not in backbone_config) self.assertTrue("compile_config" not in backbone_config) # Try config class. - self.assertEqual(BertBackbone, check_config_class(save_dir)) + self.assertEqual(BertBackbone, check_config_class(backbone_config)) # Try loading the model from preset directory. restored_backbone = Backbone.from_preset(save_dir) diff --git a/keras_nlp/src/models/bart/bart_preprocessor.py b/keras_nlp/src/models/bart/bart_preprocessor.py index c16713e7cc..dc013779a3 100644 --- a/keras_nlp/src/models/bart/bart_preprocessor.py +++ b/keras_nlp/src/models/bart/bart_preprocessor.py @@ -17,6 +17,7 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.bart.bart_backbone import BartBackbone from keras_nlp.src.models.bart.bart_tokenizer import BartTokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -127,6 +128,7 @@ class BartPreprocessor(Preprocessor): ``` """ + backbone_cls = BartBackbone tokenizer_cls = BartTokenizer def __init__( diff --git a/keras_nlp/src/models/bart/bart_tokenizer.py b/keras_nlp/src/models/bart/bart_tokenizer.py index 0e69c6ebda..115c7a60a1 100644 --- a/keras_nlp/src/models/bart/bart_tokenizer.py +++ b/keras_nlp/src/models/bart/bart_tokenizer.py @@ -14,6 +14,7 @@ from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.bart.bart_backbone import BartBackbone from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer @@ -73,6 +74,8 @@ class BartTokenizer(BytePairTokenizer): ``` """ + backbone_cls = BartBackbone + def __init__( self, vocabulary=None, diff --git a/keras_nlp/src/models/bert/bert_preprocessor.py b/keras_nlp/src/models/bert/bert_preprocessor.py index 
581ef2f457..eab9f971ec 100644 --- a/keras_nlp/src/models/bert/bert_preprocessor.py +++ b/keras_nlp/src/models/bert/bert_preprocessor.py @@ -18,6 +18,7 @@ from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) +from keras_nlp.src.models.bert.bert_backbone import BertBackbone from keras_nlp.src.models.bert.bert_tokenizer import BertTokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -122,6 +123,7 @@ class BertPreprocessor(Preprocessor): ``` """ + backbone_cls = BertBackbone tokenizer_cls = BertTokenizer def __init__( diff --git a/keras_nlp/src/models/bert/bert_tokenizer.py b/keras_nlp/src/models/bert/bert_tokenizer.py index 79aa916eed..6fd6b32ef7 100644 --- a/keras_nlp/src/models/bert/bert_tokenizer.py +++ b/keras_nlp/src/models/bert/bert_tokenizer.py @@ -13,6 +13,7 @@ # limitations under the License. from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.bert.bert_backbone import BertBackbone from keras_nlp.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer @@ -68,6 +69,8 @@ class BertTokenizer(WordPieceTokenizer): ``` """ + backbone_cls = BertBackbone + def __init__( self, vocabulary=None, diff --git a/keras_nlp/src/models/bloom/bloom_preprocessor.py b/keras_nlp/src/models/bloom/bloom_preprocessor.py index 16a05fb8f1..6c572591b7 100644 --- a/keras_nlp/src/models/bloom/bloom_preprocessor.py +++ b/keras_nlp/src/models/bloom/bloom_preprocessor.py @@ -17,6 +17,7 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.bloom.bloom_backbone import BloomBackbone from keras_nlp.src.models.bloom.bloom_tokenizer import BloomTokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -103,6 +104,7 @@ class 
BloomPreprocessor(Preprocessor): ``` """ + backbone_cls = BloomBackbone tokenizer_cls = BloomTokenizer def __init__( diff --git a/keras_nlp/src/models/bloom/bloom_tokenizer.py b/keras_nlp/src/models/bloom/bloom_tokenizer.py index ffe215b6e5..d935c1aa73 100644 --- a/keras_nlp/src/models/bloom/bloom_tokenizer.py +++ b/keras_nlp/src/models/bloom/bloom_tokenizer.py @@ -14,6 +14,7 @@ from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.bloom.bloom_backbone import BloomBackbone from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer @@ -65,6 +66,8 @@ class BloomTokenizer(BytePairTokenizer): ``` """ + backbone_cls = BloomBackbone + def __init__( self, vocabulary=None, diff --git a/keras_nlp/src/models/deberta_v3/deberta_v3_preprocessor.py b/keras_nlp/src/models/deberta_v3/deberta_v3_preprocessor.py index ad5bcaf318..e45e9f4b4d 100644 --- a/keras_nlp/src/models/deberta_v3/deberta_v3_preprocessor.py +++ b/keras_nlp/src/models/deberta_v3/deberta_v3_preprocessor.py @@ -19,6 +19,9 @@ from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) +from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import ( + DebertaV3Backbone, +) from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import ( DebertaV3Tokenizer, ) @@ -145,6 +148,7 @@ class DebertaV3Preprocessor(Preprocessor): ``` """ + backbone_cls = DebertaV3Backbone tokenizer_cls = DebertaV3Tokenizer def __init__( diff --git a/keras_nlp/src/models/deberta_v3/deberta_v3_tokenizer.py b/keras_nlp/src/models/deberta_v3/deberta_v3_tokenizer.py index d9417af24b..7a0bb8f2dd 100644 --- a/keras_nlp/src/models/deberta_v3/deberta_v3_tokenizer.py +++ b/keras_nlp/src/models/deberta_v3/deberta_v3_tokenizer.py @@ -14,6 +14,9 @@ from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.deberta_v3.deberta_v3_backbone import ( + DebertaV3Backbone, +) from keras_nlp.src.tokenizers.sentence_piece_tokenizer import ( SentencePieceTokenizer, 
) @@ -94,6 +97,8 @@ class DebertaV3Tokenizer(SentencePieceTokenizer): ``` """ + backbone_cls = DebertaV3Backbone + def __init__(self, proto, **kwargs): self.cls_token = "[CLS]" self.sep_token = "[SEP]" diff --git a/keras_nlp/src/models/distil_bert/distil_bert_preprocessor.py b/keras_nlp/src/models/distil_bert/distil_bert_preprocessor.py index 2a1fd67cd0..3e5b57dacf 100644 --- a/keras_nlp/src/models/distil_bert/distil_bert_preprocessor.py +++ b/keras_nlp/src/models/distil_bert/distil_bert_preprocessor.py @@ -19,6 +19,9 @@ from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) +from keras_nlp.src.models.distil_bert.distil_bert_backbone import ( + DistilBertBackbone, +) from keras_nlp.src.models.distil_bert.distil_bert_tokenizer import ( DistilBertTokenizer, ) @@ -114,6 +117,7 @@ class DistilBertPreprocessor(Preprocessor): ``` """ + backbone_cls = DistilBertBackbone tokenizer_cls = DistilBertTokenizer def __init__( diff --git a/keras_nlp/src/models/distil_bert/distil_bert_tokenizer.py b/keras_nlp/src/models/distil_bert/distil_bert_tokenizer.py index d0e818699e..f0b615f84d 100644 --- a/keras_nlp/src/models/distil_bert/distil_bert_tokenizer.py +++ b/keras_nlp/src/models/distil_bert/distil_bert_tokenizer.py @@ -14,6 +14,9 @@ from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.distil_bert.distil_bert_backbone import ( + DistilBertBackbone, +) from keras_nlp.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer @@ -70,6 +73,8 @@ class DistilBertTokenizer(WordPieceTokenizer): ``` """ + backbone_cls = DistilBertBackbone + def __init__( self, vocabulary, diff --git a/keras_nlp/src/models/electra/electra_preprocessor.py b/keras_nlp/src/models/electra/electra_preprocessor.py index 7252532328..941e5d4d13 100644 --- a/keras_nlp/src/models/electra/electra_preprocessor.py +++ b/keras_nlp/src/models/electra/electra_preprocessor.py @@ -18,6 +18,7 @@ from keras_nlp.src.layers.preprocessing.multi_segment_packer 
import ( MultiSegmentPacker, ) +from keras_nlp.src.models.electra.electra_backbone import ElectraBackbone from keras_nlp.src.models.electra.electra_tokenizer import ElectraTokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -111,6 +112,7 @@ class ElectraPreprocessor(Preprocessor): ``` """ + backbone_cls = ElectraBackbone tokenizer_cls = ElectraTokenizer def __init__( diff --git a/keras_nlp/src/models/electra/electra_tokenizer.py b/keras_nlp/src/models/electra/electra_tokenizer.py index b4d87c7b28..6b51e6d5de 100644 --- a/keras_nlp/src/models/electra/electra_tokenizer.py +++ b/keras_nlp/src/models/electra/electra_tokenizer.py @@ -13,6 +13,7 @@ # limitations under the License. from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.electra.electra_backbone import ElectraBackbone from keras_nlp.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer @@ -60,6 +61,8 @@ class ElectraTokenizer(WordPieceTokenizer): ``` """ + backbone_cls = ElectraBackbone + def __init__( self, vocabulary, diff --git a/keras_nlp/src/models/f_net/f_net_preprocessor.py b/keras_nlp/src/models/f_net/f_net_preprocessor.py index e88a8a7dc3..bc188e4c45 100644 --- a/keras_nlp/src/models/f_net/f_net_preprocessor.py +++ b/keras_nlp/src/models/f_net/f_net_preprocessor.py @@ -19,6 +19,7 @@ from keras_nlp.src.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) +from keras_nlp.src.models.f_net.f_net_backbone import FNetBackbone from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -116,6 +117,7 @@ class FNetPreprocessor(Preprocessor): ``` """ + backbone_cls = FNetBackbone tokenizer_cls = FNetTokenizer def __init__( diff --git a/keras_nlp/src/models/f_net/f_net_tokenizer.py 
b/keras_nlp/src/models/f_net/f_net_tokenizer.py index 4cfbb10207..df2b61558b 100644 --- a/keras_nlp/src/models/f_net/f_net_tokenizer.py +++ b/keras_nlp/src/models/f_net/f_net_tokenizer.py @@ -14,6 +14,7 @@ from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.f_net.f_net_backbone import FNetBackbone from keras_nlp.src.tokenizers.sentence_piece_tokenizer import ( SentencePieceTokenizer, ) @@ -61,6 +62,8 @@ class FNetTokenizer(SentencePieceTokenizer): ``` """ + backbone_cls = FNetBackbone + def __init__(self, proto, **kwargs): self.cls_token = "[CLS]" self.sep_token = "[SEP]" diff --git a/keras_nlp/src/models/falcon/falcon_preprocessor.py b/keras_nlp/src/models/falcon/falcon_preprocessor.py index 4de46aa00b..1c7fd3c138 100644 --- a/keras_nlp/src/models/falcon/falcon_preprocessor.py +++ b/keras_nlp/src/models/falcon/falcon_preprocessor.py @@ -17,6 +17,7 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -105,6 +106,7 @@ class FalconPreprocessor(Preprocessor): ``` """ + backbone_cls = FalconBackbone tokenizer_cls = FalconTokenizer def __init__( diff --git a/keras_nlp/src/models/falcon/falcon_tokenizer.py b/keras_nlp/src/models/falcon/falcon_tokenizer.py index 7d00459008..9781a78782 100644 --- a/keras_nlp/src/models/falcon/falcon_tokenizer.py +++ b/keras_nlp/src/models/falcon/falcon_tokenizer.py @@ -14,6 +14,7 @@ from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer @@ -65,6 +66,8 @@ class FalconTokenizer(BytePairTokenizer): ``` """ + 
backbone_cls = FalconBackbone + def __init__( self, vocabulary=None, diff --git a/keras_nlp/src/models/gemma/gemma_preprocessor.py b/keras_nlp/src/models/gemma/gemma_preprocessor.py index 8788a403d2..745f437ecd 100644 --- a/keras_nlp/src/models/gemma/gemma_preprocessor.py +++ b/keras_nlp/src/models/gemma/gemma_preprocessor.py @@ -17,6 +17,7 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone from keras_nlp.src.models.gemma.gemma_tokenizer import GemmaTokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -120,6 +121,7 @@ class GemmaPreprocessor(Preprocessor): ``` """ + backbone_cls = GemmaBackbone tokenizer_cls = GemmaTokenizer def __init__( diff --git a/keras_nlp/src/models/gemma/gemma_tokenizer.py b/keras_nlp/src/models/gemma/gemma_tokenizer.py index e87dee1b2e..7f79bb9184 100644 --- a/keras_nlp/src/models/gemma/gemma_tokenizer.py +++ b/keras_nlp/src/models/gemma/gemma_tokenizer.py @@ -13,6 +13,7 @@ # limitations under the License. 
from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone from keras_nlp.src.tokenizers.sentence_piece_tokenizer import ( SentencePieceTokenizer, ) @@ -77,6 +78,8 @@ class GemmaTokenizer(SentencePieceTokenizer): ``` """ + backbone_cls = GemmaBackbone + def __init__(self, proto, **kwargs): self.start_token = "" self.end_token = "" diff --git a/keras_nlp/src/models/gpt2/gpt2_preprocessor.py b/keras_nlp/src/models/gpt2/gpt2_preprocessor.py index 4e3958c402..b645ac7682 100644 --- a/keras_nlp/src/models/gpt2/gpt2_preprocessor.py +++ b/keras_nlp/src/models/gpt2/gpt2_preprocessor.py @@ -17,6 +17,7 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -105,6 +106,7 @@ class GPT2Preprocessor(Preprocessor): ``` """ + backbone_cls = GPT2Backbone tokenizer_cls = GPT2Tokenizer def __init__( diff --git a/keras_nlp/src/models/gpt2/gpt2_tokenizer.py b/keras_nlp/src/models/gpt2/gpt2_tokenizer.py index 37ee715c96..2994d81550 100644 --- a/keras_nlp/src/models/gpt2/gpt2_tokenizer.py +++ b/keras_nlp/src/models/gpt2/gpt2_tokenizer.py @@ -14,6 +14,7 @@ from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer @@ -65,6 +66,8 @@ class GPT2Tokenizer(BytePairTokenizer): ``` """ + backbone_cls = GPT2Backbone + def __init__( self, vocabulary=None, diff --git a/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py b/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py index 2fcdb56f4b..25875d03da 100644 --- 
a/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +++ b/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py @@ -16,6 +16,7 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -63,6 +64,7 @@ class GPTNeoXPreprocessor(Preprocessor): the layer. """ + backbone_cls = GPTNeoXBackbone tokenizer_cls = GPTNeoXTokenizer def __init__( diff --git a/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py b/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py index b4a0c76cd8..30d397b16c 100644 --- a/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +++ b/keras_nlp/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py @@ -13,6 +13,7 @@ # limitations under the License. from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer @@ -43,6 +44,8 @@ class GPTNeoXTokenizer(BytePairTokenizer): merge entities separated by a space. 
""" + backbone_cls = GPTNeoXBackbone + def __init__( self, vocabulary=None, diff --git a/keras_nlp/src/models/llama/llama_preprocessor.py b/keras_nlp/src/models/llama/llama_preprocessor.py index c4250a1389..8b7c5772e6 100644 --- a/keras_nlp/src/models/llama/llama_preprocessor.py +++ b/keras_nlp/src/models/llama/llama_preprocessor.py @@ -15,6 +15,7 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone from keras_nlp.src.models.llama.llama_tokenizer import LlamaTokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -108,6 +109,7 @@ class LlamaPreprocessor(Preprocessor): ``` """ + backbone_cls = LlamaBackbone tokenizer_cls = LlamaTokenizer def __init__( diff --git a/keras_nlp/src/models/llama/llama_tokenizer.py b/keras_nlp/src/models/llama/llama_tokenizer.py index 10a3e849c8..df99492904 100644 --- a/keras_nlp/src/models/llama/llama_tokenizer.py +++ b/keras_nlp/src/models/llama/llama_tokenizer.py @@ -13,6 +13,7 @@ # limitations under the License. from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone from keras_nlp.src.tokenizers.sentence_piece_tokenizer import ( SentencePieceTokenizer, ) @@ -60,6 +61,8 @@ class LlamaTokenizer(SentencePieceTokenizer): ``` """ + backbone_cls = LlamaBackbone + def __init__(self, proto, **kwargs): self.start_token = "" self.end_token = "" diff --git a/keras_nlp/src/models/llama3/llama3_preprocessor.py b/keras_nlp/src/models/llama3/llama3_preprocessor.py index 0b767c27c7..b1fc5769ab 100644 --- a/keras_nlp/src/models/llama3/llama3_preprocessor.py +++ b/keras_nlp/src/models/llama3/llama3_preprocessor.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer from keras_nlp.src.models.llama.llama_preprocessor import LlamaPreprocessor @keras_nlp_export("keras_nlp.models.Llama3Preprocessor") class Llama3Preprocessor(LlamaPreprocessor): + backbone_cls = Llama3Backbone tokenizer_cls = Llama3Tokenizer diff --git a/keras_nlp/src/models/llama3/llama3_tokenizer.py b/keras_nlp/src/models/llama3/llama3_tokenizer.py index 2b22210edd..c1bcd684be 100644 --- a/keras_nlp/src/models/llama3/llama3_tokenizer.py +++ b/keras_nlp/src/models/llama3/llama3_tokenizer.py @@ -13,11 +13,14 @@ # limitations under the License. from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer @keras_nlp_export("keras_nlp.models.Llama3Tokenizer") class Llama3Tokenizer(BytePairTokenizer): + backbone_cls = Llama3Backbone + def __init__( self, vocabulary=None, diff --git a/keras_nlp/src/models/mistral/mistral_preprocessor.py b/keras_nlp/src/models/mistral/mistral_preprocessor.py index f2e434abb5..0278103c54 100644 --- a/keras_nlp/src/models/mistral/mistral_preprocessor.py +++ b/keras_nlp/src/models/mistral/mistral_preprocessor.py @@ -16,6 +16,7 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.mistral.mistral_backbone import MistralBackbone from keras_nlp.src.models.mistral.mistral_tokenizer import MistralTokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -109,6 +110,7 @@ class MistralPreprocessor(Preprocessor): ``` """ + backbone_cls = MistralBackbone tokenizer_cls = MistralTokenizer def __init__( diff --git 
a/keras_nlp/src/models/mistral/mistral_tokenizer.py b/keras_nlp/src/models/mistral/mistral_tokenizer.py index 4355bdcd0c..47208dfa68 100644 --- a/keras_nlp/src/models/mistral/mistral_tokenizer.py +++ b/keras_nlp/src/models/mistral/mistral_tokenizer.py @@ -13,6 +13,7 @@ # limitations under the License. from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.mistral.mistral_backbone import MistralBackbone from keras_nlp.src.tokenizers.sentence_piece_tokenizer import ( SentencePieceTokenizer, ) @@ -60,6 +61,8 @@ class MistralTokenizer(SentencePieceTokenizer): ``` """ + backbone_cls = MistralBackbone + def __init__(self, proto, **kwargs): self.start_token = "" self.end_token = "" diff --git a/keras_nlp/src/models/opt/opt_preprocessor.py b/keras_nlp/src/models/opt/opt_preprocessor.py index 31d6256838..3b39c82cbc 100644 --- a/keras_nlp/src/models/opt/opt_preprocessor.py +++ b/keras_nlp/src/models/opt/opt_preprocessor.py @@ -17,6 +17,7 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.opt.opt_backbone import OPTBackbone from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer from keras_nlp.src.models.preprocessor import Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -105,6 +106,7 @@ class OPTPreprocessor(Preprocessor): ``` """ + backbone_cls = OPTBackbone tokenizer_cls = OPTTokenizer def __init__( diff --git a/keras_nlp/src/models/opt/opt_tokenizer.py b/keras_nlp/src/models/opt/opt_tokenizer.py index b12cb156c9..1a98b5d3ef 100644 --- a/keras_nlp/src/models/opt/opt_tokenizer.py +++ b/keras_nlp/src/models/opt/opt_tokenizer.py @@ -14,6 +14,7 @@ from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.opt.opt_backbone import OPTBackbone from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer @@ -65,6 +66,8 @@ class OPTTokenizer(BytePairTokenizer): ``` 
""" + backbone_cls = OPTBackbone + def __init__( self, vocabulary=None, diff --git a/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py b/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py index 70a96e9952..bd4a29d3c3 100644 --- a/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +++ b/keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py @@ -21,6 +21,9 @@ from keras_nlp.src.models.gemma.gemma_causal_lm_preprocessor import ( GemmaCausalLMPreprocessor, ) +from keras_nlp.src.models.pali_gemma.pali_gemma_backbone import ( + PaliGemmaBackbone, +) from keras_nlp.src.models.pali_gemma.pali_gemma_tokenizer import ( PaliGemmaTokenizer, ) @@ -29,6 +32,7 @@ @keras_nlp_export("keras_nlp.models.PaliGemmaCausalLMPreprocessor") class PaliGemmaCausalLMPreprocessor(GemmaCausalLMPreprocessor): + backbone_cls = PaliGemmaBackbone tokenizer_cls = PaliGemmaTokenizer def __init__( diff --git a/keras_nlp/src/models/pali_gemma/pali_gemma_tokenizer.py b/keras_nlp/src/models/pali_gemma/pali_gemma_tokenizer.py index 9abd35c1ec..a274974ac0 100644 --- a/keras_nlp/src/models/pali_gemma/pali_gemma_tokenizer.py +++ b/keras_nlp/src/models/pali_gemma/pali_gemma_tokenizer.py @@ -13,6 +13,9 @@ # limitations under the License. 
from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.models.gemma.gemma_preprocessor import GemmaTokenizer +from keras_nlp.src.models.pali_gemma.pali_gemma_backbone import ( + PaliGemmaBackbone, +) @keras_nlp_export("keras_nlp.models.PaliGemmaTokenizer") @@ -76,4 +79,6 @@ class PaliGemmaTokenizer(GemmaTokenizer): ``` """ + backbone_cls = PaliGemmaBackbone + pass diff --git a/keras_nlp/src/models/phi3/phi3_causal_lm.py b/keras_nlp/src/models/phi3/phi3_causal_lm.py index 802669f31c..a567782b38 100644 --- a/keras_nlp/src/models/phi3/phi3_causal_lm.py +++ b/keras_nlp/src/models/phi3/phi3_causal_lm.py @@ -19,7 +19,6 @@ from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import ( Phi3CausalLMPreprocessor, ) -from keras_nlp.src.utils.python_utils import classproperty from keras_nlp.src.utils.tensor_utils import any_equal @@ -46,6 +45,9 @@ class Phi3CausalLM(CausalLM): should be preprocessed before calling the model. """ + backbone_cls = Phi3Backbone + preprocessor_cls = Phi3CausalLMPreprocessor + def __init__(self, backbone, preprocessor=None, **kwargs): # === Layers === self.backbone = backbone @@ -61,14 +63,6 @@ def __init__(self, backbone, preprocessor=None, **kwargs): **kwargs, ) - @classproperty - def backbone_cls(cls): - return Phi3Backbone - - @classproperty - def preprocessor_cls(cls): - return Phi3CausalLMPreprocessor - def call_with_cache( self, token_ids, diff --git a/keras_nlp/src/models/phi3/phi3_preprocessor.py b/keras_nlp/src/models/phi3/phi3_preprocessor.py index 3a9f291cd2..caa5c9eab4 100644 --- a/keras_nlp/src/models/phi3/phi3_preprocessor.py +++ b/keras_nlp/src/models/phi3/phi3_preprocessor.py @@ -15,6 +15,7 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer from keras_nlp.src.models.preprocessor import 
Preprocessor from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -108,6 +109,7 @@ class Phi3Preprocessor(Preprocessor): ``` """ + backbone_cls = Phi3Backbone tokenizer_cls = Phi3Tokenizer def __init__( diff --git a/keras_nlp/src/models/phi3/phi3_tokenizer.py b/keras_nlp/src/models/phi3/phi3_tokenizer.py index d45201ff6b..7d535cb74b 100644 --- a/keras_nlp/src/models/phi3/phi3_tokenizer.py +++ b/keras_nlp/src/models/phi3/phi3_tokenizer.py @@ -14,6 +14,7 @@ import copy from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone from keras_nlp.src.models.phi3.phi3_presets import backbone_presets from keras_nlp.src.tokenizers.sentence_piece_tokenizer import ( SentencePieceTokenizer, @@ -63,6 +64,8 @@ class Phi3Tokenizer(SentencePieceTokenizer): ``` """ + backbone_cls = Phi3Backbone + def __init__(self, proto, **kwargs): self.start_token = "" self.end_token = "<|endoftext|>" diff --git a/keras_nlp/src/models/preprocessor.py b/keras_nlp/src/models/preprocessor.py index 3d5b7ce40b..e67cad47a6 100644 --- a/keras_nlp/src/models/preprocessor.py +++ b/keras_nlp/src/models/preprocessor.py @@ -19,13 +19,10 @@ PreprocessingLayer, ) from keras_nlp.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE -from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE -from keras_nlp.src.utils.preset_utils import check_config_class -from keras_nlp.src.utils.preset_utils import check_file_exists -from keras_nlp.src.utils.preset_utils import check_format +from keras_nlp.src.utils.preset_utils import find_subclass +from keras_nlp.src.utils.preset_utils import get_preset_loader from keras_nlp.src.utils.preset_utils import list_presets from keras_nlp.src.utils.preset_utils import list_subclasses -from keras_nlp.src.utils.preset_utils import load_serialized_object from keras_nlp.src.utils.preset_utils import save_serialized_object from keras_nlp.src.utils.python_utils import classproperty @@ -45,6 +42,7 @@ 
class Preprocessor(PreprocessingLayer): should set the `tokenizer` property on construction. """ + backbone_cls = None tokenizer_cls = None def __init__(self, *args, **kwargs): @@ -128,70 +126,19 @@ def from_preset( ) ``` """ - format = check_format(preset) - - if format == "transformers": - if cls.tokenizer_cls is None: - raise ValueError("Tokenizer class is None") - tokenizer = cls.tokenizer_cls.from_preset(preset) - return cls(tokenizer=tokenizer, **kwargs) - if cls == Preprocessor: raise ValueError( "Do not call `Preprocessor.from_preset()` directly. Instead call a " "choose a particular task class, e.g. " "`keras_nlp.models.BertPreprocessor.from_preset()`." ) - # Check if we should load a `preprocessor.json` directly. - load_preprocessor_config = False - if check_file_exists(preset, PREPROCESSOR_CONFIG_FILE): - preprocessor_preset_cls = check_config_class( - preset, PREPROCESSOR_CONFIG_FILE - ) - if issubclass(preprocessor_preset_cls, cls): - load_preprocessor_config = True - if load_preprocessor_config: - # Preprocessor case. - preprocessor = load_serialized_object( - preset, - PREPROCESSOR_CONFIG_FILE, - ) - preprocessor.tokenizer.load_preset_assets(preset) - return preprocessor - - # Tokenizer case. - # If `preprocessor.json` doesn't exist or preprocessor preset class is - # different from the calling class, create the preprocessor based on - # `tokenizer.json`. - tokenizer_preset_cls = check_config_class( - preset, config_file=TOKENIZER_CONFIG_FILE - ) - if tokenizer_preset_cls is not cls.tokenizer_cls: - subclasses = list_subclasses(cls) - subclasses = tuple( - filter( - lambda x: x.tokenizer_cls == tokenizer_preset_cls, - subclasses, - ) - ) - if len(subclasses) == 0: - raise ValueError( - f"No registered subclass of `{cls.__name__}` can load " - f"a `{tokenizer_preset_cls.__name__}`." - ) - if len(subclasses) > 1: - names = ", ".join(f"`{x.__name__}`" for x in subclasses) - raise ValueError( - f"Ambiguous call to `{cls.__name__}.from_preset()`. 
" - f"Found multiple possible subclasses {names}. " - "Please call `from_preset` on a subclass directly." - ) - - tokenizer = load_serialized_object(preset, TOKENIZER_CONFIG_FILE) - tokenizer.load_preset_assets(preset) - preprocessor = cls(tokenizer=tokenizer, **kwargs) - - return preprocessor + + loader = get_preset_loader(preset) + backbone_cls = loader.check_backbone_class() + # Detect the correct subclass if we need to. + if cls.backbone_cls != backbone_cls: + cls = find_subclass(preset, cls, backbone_cls) + return loader.load_preprocessor(cls, **kwargs) def save_to_preset(self, preset_dir): """Save preprocessor to a preset directory. diff --git a/keras_nlp/src/models/preprocessor_test.py b/keras_nlp/src/models/preprocessor_test.py index 1bfbe9b4f0..6e79af3975 100644 --- a/keras_nlp/src/models/preprocessor_test.py +++ b/keras_nlp/src/models/preprocessor_test.py @@ -30,9 +30,10 @@ from keras_nlp.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR from keras_nlp.src.utils.preset_utils import check_config_class +from keras_nlp.src.utils.preset_utils import load_json -class TestTask(TestCase): +class TestPreprocessor(TestCase): def test_preset_accessors(self): bert_presets = set(BertPreprocessor.presets.keys()) gpt2_presets = set(GPT2Preprocessor.presets.keys()) @@ -68,7 +69,7 @@ def test_from_preset_errors(self): BertPreprocessor.from_preset("gpt2_base_en") with self.assertRaises(ValueError): # No loading on a non-keras model. - Preprocessor.from_preset("hf://google-bert/bert-base-uncased") + BertPreprocessor.from_preset("hf://spacy/en_core_web_sm") # TODO: Add more tests when we added a model that has `preprocessor.json`. @@ -109,6 +110,5 @@ def test_save_to_preset(self, cls, preset_name, tokenizer_type): ) # Check config class. 
- self.assertEqual( - cls, check_config_class(save_dir, PREPROCESSOR_CONFIG_FILE) - ) + preprocessor_config = load_json(save_dir, PREPROCESSOR_CONFIG_FILE) + self.assertEqual(cls, check_config_class(preprocessor_config)) diff --git a/keras_nlp/src/models/roberta/roberta_preprocessor.py b/keras_nlp/src/models/roberta/roberta_preprocessor.py index 1ed9f2dbad..23280f47f1 100644 --- a/keras_nlp/src/models/roberta/roberta_preprocessor.py +++ b/keras_nlp/src/models/roberta/roberta_preprocessor.py @@ -20,6 +20,7 @@ MultiSegmentPacker, ) from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.src.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -129,6 +130,7 @@ class RobertaPreprocessor(Preprocessor): ``` """ + backbone_cls = RobertaBackbone tokenizer_cls = RobertaTokenizer def __init__( diff --git a/keras_nlp/src/models/roberta/roberta_tokenizer.py b/keras_nlp/src/models/roberta/roberta_tokenizer.py index 9daaa7c199..23ac1fa46c 100644 --- a/keras_nlp/src/models/roberta/roberta_tokenizer.py +++ b/keras_nlp/src/models/roberta/roberta_tokenizer.py @@ -14,6 +14,7 @@ from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer @@ -72,6 +73,8 @@ class RobertaTokenizer(BytePairTokenizer): ``` """ + backbone_cls = RobertaBackbone + def __init__( self, vocabulary=None, diff --git a/keras_nlp/src/models/t5/t5_tokenizer.py b/keras_nlp/src/models/t5/t5_tokenizer.py index 9dccd0d80b..43f31ee9ab 100644 --- a/keras_nlp/src/models/t5/t5_tokenizer.py +++ b/keras_nlp/src/models/t5/t5_tokenizer.py @@ -13,6 +13,7 @@ # limitations under the License. 
from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.t5.t5_backbone import T5Backbone from keras_nlp.src.tokenizers.sentence_piece_tokenizer import ( SentencePieceTokenizer, ) @@ -74,6 +75,8 @@ class T5Tokenizer(SentencePieceTokenizer): ``` """ + backbone_cls = T5Backbone + def __init__(self, proto, **kwargs): self.end_token = "" self.pad_token = "" diff --git a/keras_nlp/src/models/task.py b/keras_nlp/src/models/task.py index abee4ecf29..86a7708cfc 100644 --- a/keras_nlp/src/models/task.py +++ b/keras_nlp/src/models/task.py @@ -22,18 +22,12 @@ from keras_nlp.src.api_export import keras_nlp_export from keras_nlp.src.utils.keras_utils import print_msg from keras_nlp.src.utils.pipeline_model import PipelineModel -from keras_nlp.src.utils.preset_utils import CONFIG_FILE -from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE from keras_nlp.src.utils.preset_utils import TASK_CONFIG_FILE from keras_nlp.src.utils.preset_utils import TASK_WEIGHTS_FILE -from keras_nlp.src.utils.preset_utils import check_config_class -from keras_nlp.src.utils.preset_utils import check_file_exists -from keras_nlp.src.utils.preset_utils import check_format -from keras_nlp.src.utils.preset_utils import get_file -from keras_nlp.src.utils.preset_utils import jax_memory_cleanup +from keras_nlp.src.utils.preset_utils import find_subclass +from keras_nlp.src.utils.preset_utils import get_preset_loader from keras_nlp.src.utils.preset_utils import list_presets from keras_nlp.src.utils.preset_utils import list_subclasses -from keras_nlp.src.utils.preset_utils import load_serialized_object from keras_nlp.src.utils.preset_utils import save_serialized_object from keras_nlp.src.utils.python_utils import classproperty @@ -195,18 +189,6 @@ def from_preset( ) ``` """ - format = check_format(preset) - - if format == "transformers": - if cls.backbone_cls is None: - raise ValueError("Backbone class is None") - if cls.preprocessor_cls is None: - raise ValueError("Preprocessor 
class is None") - - backbone = cls.backbone_cls.from_preset(preset) - preprocessor = cls.preprocessor_cls.from_preset(preset) - return cls(backbone=backbone, preprocessor=preprocessor, **kwargs) - if cls == Task: raise ValueError( "Do not call `Task.from_preset()` directly. Instead call a " @@ -214,69 +196,13 @@ def from_preset( "`keras_nlp.models.Classifier.from_preset()` or " "`keras_nlp.models.BertClassifier.from_preset()`." ) - if "backbone" in kwargs: - raise ValueError( - "You cannot pass a `backbone` argument to the `from_preset` " - f"method. Instead, call the {cls.__name__} default " - "constructor with a `backbone` argument. " - f"Received: backbone={kwargs['backbone']}." - ) - # Check if we should load a `task.json` directly. - load_task_config = False - if check_file_exists(preset, TASK_CONFIG_FILE): - task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE) - if issubclass(task_preset_cls, cls): - load_task_config = True - if load_task_config: - # Task case. - task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE) - task = load_serialized_object(preset, TASK_CONFIG_FILE) - if load_weights: - jax_memory_cleanup(task) - if check_file_exists(preset, TASK_WEIGHTS_FILE): - task.load_task_weights(get_file(preset, TASK_WEIGHTS_FILE)) - task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE)) - task.preprocessor.tokenizer.load_preset_assets(preset) - return task - - # Backbone case. - # If `task.json` doesn't exist or the task preset class is different - # from the calling class, create the task based on `config.json`. - backbone_preset_cls = check_config_class(preset, CONFIG_FILE) - if backbone_preset_cls is not cls.backbone_cls: - subclasses = list_subclasses(cls) - subclasses = tuple( - filter( - lambda x: x.backbone_cls == backbone_preset_cls, - subclasses, - ) - ) - if len(subclasses) == 0: - raise ValueError( - f"No registered subclass of `{cls.__name__}` can load " - f"a `{backbone_preset_cls.__name__}`." 
- ) - if len(subclasses) > 1: - names = ", ".join(f"`{x.__name__}`" for x in subclasses) - raise ValueError( - f"Ambiguous call to `{cls.__name__}.from_preset()`. " - f"Found multiple possible subclasses {names}. " - "Please call `from_preset` on a subclass directly." - ) - cls = subclasses[0] - # Forward dtype to the backbone. - backbone_kwargs = {} - if "dtype" in kwargs: - backbone_kwargs = {"dtype": kwargs.pop("dtype")} - backbone = backbone_preset_cls.from_preset( - preset, load_weights=load_weights, **backbone_kwargs - ) - if "preprocessor" in kwargs: - preprocessor = kwargs.pop("preprocessor") - else: - preprocessor = cls.preprocessor_cls.from_preset(preset) - return cls(backbone=backbone, preprocessor=preprocessor, **kwargs) + loader = get_preset_loader(preset) + backbone_cls = loader.check_backbone_class() + # Detect the correct subclass if we need to. + if cls.backbone_cls != backbone_cls: + cls = find_subclass(preset, cls, backbone_cls) + return loader.load_task(cls, load_weights, **kwargs) def load_task_weights(self, filepath): """Load only the tasks specific weights not in the backbone.""" diff --git a/keras_nlp/src/models/task_test.py b/keras_nlp/src/models/task_test.py index 2b02c41d46..dc084bd8a8 100644 --- a/keras_nlp/src/models/task_test.py +++ b/keras_nlp/src/models/task_test.py @@ -31,7 +31,7 @@ from keras_nlp.src.utils.preset_utils import TASK_CONFIG_FILE from keras_nlp.src.utils.preset_utils import TASK_WEIGHTS_FILE from keras_nlp.src.utils.preset_utils import check_config_class -from keras_nlp.src.utils.preset_utils import load_config +from keras_nlp.src.utils.preset_utils import load_json class SimpleTokenizer(Tokenizer): @@ -91,7 +91,7 @@ def test_from_preset_errors(self): BertClassifier.from_preset("gpt2_base_en", load_weights=False) with self.assertRaises(ValueError): # No loading on a non-keras model. 
- CausalLM.from_preset("hf://google-bert/bert-base-uncased") + CausalLM.from_preset("hf://spacy/en_core_web_sm") def test_summary_with_preprocessor(self): preprocessor = SimplePreprocessor() @@ -126,16 +126,14 @@ def test_save_to_preset(self): ) # Check the task config (`task.json`). - task_config = load_config(save_dir, TASK_CONFIG_FILE) + task_config = load_json(save_dir, TASK_CONFIG_FILE) self.assertTrue("build_config" not in task_config) self.assertTrue("compile_config" not in task_config) self.assertTrue("backbone" in task_config["config"]) self.assertTrue("preprocessor" in task_config["config"]) # Check the preset directory task class. - self.assertEqual( - BertClassifier, check_config_class(save_dir, TASK_CONFIG_FILE) - ) + self.assertEqual(BertClassifier, check_config_class(task_config)) # Try loading the model from preset directory. restored_model = Classifier.from_preset(save_dir) diff --git a/keras_nlp/src/models/whisper/whisper_preprocessor.py b/keras_nlp/src/models/whisper/whisper_preprocessor.py index 8659fd8319..bc83139952 100644 --- a/keras_nlp/src/models/whisper/whisper_preprocessor.py +++ b/keras_nlp/src/models/whisper/whisper_preprocessor.py @@ -22,6 +22,7 @@ from keras_nlp.src.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) +from keras_nlp.src.models.whisper.whisper_backbone import WhisperBackbone from keras_nlp.src.models.whisper.whisper_tokenizer import WhisperTokenizer from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function @@ -148,6 +149,7 @@ class WhisperPreprocessor(Preprocessor): ``` """ + backbone_cls = WhisperBackbone tokenizer_cls = WhisperTokenizer def __init__( diff --git a/keras_nlp/src/models/whisper/whisper_tokenizer.py b/keras_nlp/src/models/whisper/whisper_tokenizer.py index e917f1b780..43760f1e6f 100644 --- a/keras_nlp/src/models/whisper/whisper_tokenizer.py +++ b/keras_nlp/src/models/whisper/whisper_tokenizer.py @@ -15,6 +15,7 @@ import json from keras_nlp.src.api_export 
import keras_nlp_export +from keras_nlp.src.models.whisper.whisper_backbone import WhisperBackbone from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer @@ -47,6 +48,8 @@ class WhisperTokenizer(BytePairTokenizer): tokenizer. """ + backbone_cls = WhisperBackbone + def __init__( self, vocabulary=None, diff --git a/keras_nlp/src/models/xlm_roberta/xlm_roberta_preprocessor.py b/keras_nlp/src/models/xlm_roberta/xlm_roberta_preprocessor.py index 634e3240a9..61586bdd22 100644 --- a/keras_nlp/src/models/xlm_roberta/xlm_roberta_preprocessor.py +++ b/keras_nlp/src/models/xlm_roberta/xlm_roberta_preprocessor.py @@ -20,6 +20,9 @@ MultiSegmentPacker, ) from keras_nlp.src.models.preprocessor import Preprocessor +from keras_nlp.src.models.xlm_roberta.xlm_roberta_backbone import ( + XLMRobertaBackbone, +) from keras_nlp.src.models.xlm_roberta.xlm_roberta_tokenizer import ( XLMRobertaTokenizer, ) @@ -142,6 +145,7 @@ def train_sentencepiece(ds, vocab_size): ``` """ + backbone_cls = XLMRobertaBackbone tokenizer_cls = XLMRobertaTokenizer def __init__( diff --git a/keras_nlp/src/models/xlm_roberta/xlm_roberta_tokenizer.py b/keras_nlp/src/models/xlm_roberta/xlm_roberta_tokenizer.py index 50dce83349..78e9975f0d 100644 --- a/keras_nlp/src/models/xlm_roberta/xlm_roberta_tokenizer.py +++ b/keras_nlp/src/models/xlm_roberta/xlm_roberta_tokenizer.py @@ -14,6 +14,9 @@ from keras_nlp.src.api_export import keras_nlp_export +from keras_nlp.src.models.xlm_roberta.xlm_roberta_backbone import ( + XLMRobertaBackbone, +) from keras_nlp.src.tokenizers.sentence_piece_tokenizer import ( SentencePieceTokenizer, ) @@ -89,6 +92,8 @@ def train_sentencepiece(ds, vocab_size): ``` """ + backbone_cls = XLMRobertaBackbone + def __init__(self, proto, **kwargs): # List of special tokens. 
self._vocabulary_prefix = ["", "", "", ""] diff --git a/keras_nlp/src/tokenizers/tokenizer.py b/keras_nlp/src/tokenizers/tokenizer.py index b35aa53da6..81c7b50176 100644 --- a/keras_nlp/src/tokenizers/tokenizer.py +++ b/keras_nlp/src/tokenizers/tokenizer.py @@ -19,17 +19,15 @@ ) from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE -from keras_nlp.src.utils.preset_utils import check_config_class -from keras_nlp.src.utils.preset_utils import check_format +from keras_nlp.src.utils.preset_utils import find_subclass from keras_nlp.src.utils.preset_utils import get_file +from keras_nlp.src.utils.preset_utils import get_preset_loader from keras_nlp.src.utils.preset_utils import list_presets from keras_nlp.src.utils.preset_utils import list_subclasses -from keras_nlp.src.utils.preset_utils import load_serialized_object from keras_nlp.src.utils.preset_utils import save_serialized_object from keras_nlp.src.utils.preset_utils import save_tokenizer_assets from keras_nlp.src.utils.python_utils import classproperty from keras_nlp.src.utils.tensor_utils import tf_preprocessing_function -from keras_nlp.src.utils.transformers.convert import load_transformers_tokenizer @keras_nlp_export( @@ -80,6 +78,8 @@ def detokenize(self, inputs): ``` """ + backbone_cls = None + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.file_assets = None @@ -209,7 +209,7 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from Examples: ```python # Load a preset tokenizer. - tokenizer = keras_nlp.tokenizerTokenizer.from_preset("bert_base_en") + tokenizer = keras_nlp.tokenizer.Tokenizer.from_preset("bert_base_en") # Tokenize some input. 
tokenizer("The quick brown fox tripped.") @@ -218,20 +218,8 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from tokenizer.detokenize([5, 6, 7, 8, 9]) ``` """ - format = check_format(preset) - if format == "transformers": - return load_transformers_tokenizer(cls, preset) - - preset_cls = check_config_class( - preset, config_file=TOKENIZER_CONFIG_FILE - ) - if not issubclass(preset_cls, cls): - raise ValueError( - f"Preset has type `{preset_cls.__name__}` which is not a " - f"a subclass of calling class `{cls.__name__}`. Call " - f"`from_preset` directly on `{preset_cls.__name__}` instead." - ) - - tokenizer = load_serialized_object(preset, TOKENIZER_CONFIG_FILE) - tokenizer.load_preset_assets(preset) - return tokenizer + loader = get_preset_loader(preset) + backbone_cls = loader.check_backbone_class() + if cls.backbone_cls != backbone_cls: + cls = find_subclass(preset, cls, backbone_cls) + return loader.load_tokenizer(cls, **kwargs) diff --git a/keras_nlp/src/tokenizers/tokenizer_test.py b/keras_nlp/src/tokenizers/tokenizer_test.py index 6340982e77..9e16c694bf 100644 --- a/keras_nlp/src/tokenizers/tokenizer_test.py +++ b/keras_nlp/src/tokenizers/tokenizer_test.py @@ -27,6 +27,7 @@ from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE from keras_nlp.src.utils.preset_utils import check_config_class +from keras_nlp.src.utils.preset_utils import load_json class SimpleTokenizer(Tokenizer): @@ -64,7 +65,7 @@ def test_from_preset_errors(self): GPT2Tokenizer.from_preset("bert_tiny_en_uncased") with self.assertRaises(ValueError): # No loading on a non-keras model. 
def find_subclass(preset, cls, backbone_cls):
    """Find a subclass of `cls` that is compatible with `backbone_cls`.

    Args:
        preset: the preset identifier, used only in error messages.
        cls: the class `from_preset` was called on (e.g. a task base class).
        backbone_cls: the backbone architecture declared by the preset.

    Returns:
        The matching registered subclass of `cls`. When several match, direct
        subclasses of `cls` are preferred, then the subclass registered first
        (so built-in classes win over user-registered ones).

    Raises:
        ValueError: if no registered subclass of `cls` uses `backbone_cls`.
    """
    subclasses = list_subclasses(cls)
    subclasses = filter(lambda x: x.backbone_cls == backbone_cls, subclasses)
    subclasses = list(subclasses)
    if not subclasses:
        raise ValueError(
            f"Unable to find a subclass of {cls.__name__} that is compatible "
            f"with {backbone_cls.__name__} found in preset '{preset}'."
        )
    # If we find multiple subclasses, try to filter to direct subclasses of
    # the class we are trying to instantiate. A direct subclass `x` has `cls`
    # among its bases; the previous check (`x in cls.__bases__`) was inverted
    # and could never match, since `x` is itself a subclass of `cls`.
    if len(subclasses) > 1:
        directs = list(filter(lambda x: cls in x.__bases__, subclasses))
        # Prefer direct subclasses whenever any exist (including exactly one).
        if directs:
            subclasses = directs
    # Return the subclass that was registered first (prefer built in classes).
    return subclasses[0]
def load_json(preset, config_file=CONFIG_FILE):
    """Download (if needed) and parse a JSON file from a preset.

    Args:
        preset: the preset identifier (local directory or remote URI).
        config_file: name of the JSON file inside the preset directory.

    Returns:
        The parsed JSON contents as a Python object.
    """
    config_path = get_file(preset, config_file)
    # Use a distinct name for the file handle; previously it shadowed the
    # `config_file` parameter inside the `with` block.
    with open(config_path, encoding="utf-8") as file:
        return json.load(file)
def check_config_class(config):
    """Validate a preset is being loaded on the correct class."""
    # The `registered_name` key identifies the serialized Keras class.
    return keras.saving.get_registered_object(config["registered_name"])


def get_preset_loader(preset):
    """Return a `PresetLoader` able to load the given preset.

    Inspects the preset's `config.json` to decide whether it contains a
    serialized Keras object or a Transformers-style checkpoint, and returns
    the matching loader.

    Raises:
        ValueError: if the preset has no `config.json`, or its contents
            match no known format.
    """
    # Avoid circular import.
    from keras_nlp.src.utils.transformers.preset_loader import (
        TransformersPresetLoader,
    )

    if not check_file_exists(preset, CONFIG_FILE):
        raise ValueError(
            f"Preset {preset} has no {CONFIG_FILE}. Make sure the URI or "
            "directory you are trying to load is a valid KerasNLP preset and "
            "that you have permissions to read/download from this location."
        )
    # We currently assume all formats we support have a `config.json`, this is
    # true, for Keras, Transformers, and timm. We infer the on disk format by
    # inspecting the `config.json` file.
    config = load_json(preset, CONFIG_FILE)
    if "registered_name" in config:
        # If we see registered_name, we assume a serialized Keras object.
        return KerasPresetLoader(preset, config)
    elif "model_type" in config:
        # If we see model_type, we assume a Transformers style config.
        return TransformersPresetLoader(preset, config)
    else:
        contents = json.dumps(config, indent=4)
        raise ValueError(
            f"Unrecognized format for {CONFIG_FILE} in {preset}. "
            "Create a preset with the `save_to_preset` utility on KerasNLP "
            f"models. Contents of {CONFIG_FILE}:\n{contents}"
        )


class PresetLoader:
    """Base class for loading models, tokenizers and tasks from a preset.

    Subclasses implement loading for a specific on-disk format (e.g. Keras
    presets, Transformers checkpoints) by overriding `check_backbone_class`,
    `load_backbone` and `load_tokenizer`.
    """

    def __init__(self, preset, config):
        self.config = config
        self.preset = preset

    def check_backbone_class(self):
        """Infer the backbone architecture."""
        raise NotImplementedError

    def load_backbone(self, cls, load_weights, **kwargs):
        """Load the backbone model from the preset."""
        raise NotImplementedError

    def load_tokenizer(self, cls, **kwargs):
        """Load a tokenizer layer from the preset."""
        raise NotImplementedError

    def load_task(self, cls, load_weights, **kwargs):
        """Load a task model from the preset.

        By default, we create a task from a backbone and preprocessor with
        default arguments. This means a format only needs to support loading
        a backbone and tokenizer for task creation to work; the task itself
        is built with default constructor arguments.
        """
        if "backbone" not in kwargs:
            backbone_class = cls.backbone_cls
            # Forward dtype to backbone.
            backbone_kwargs = {"dtype": kwargs.pop("dtype", None)}
            kwargs["backbone"] = self.load_backbone(
                backbone_class, load_weights, **backbone_kwargs
            )
        if "preprocessor" not in kwargs:
            kwargs["preprocessor"] = self.load_preprocessor(
                cls.preprocessor_cls
            )
        return cls(**kwargs)

    def load_preprocessor(self, cls, **kwargs):
        """Load a preprocessor layer from the preset.

        By default, we create a preprocessor from a tokenizer with default
        arguments. This allows us to support transformers checkpoints by
        only converting the backbone and tokenizer.
        """
        if "tokenizer" not in kwargs:
            kwargs["tokenizer"] = self.load_tokenizer(cls.tokenizer_cls)
        return cls(**kwargs)


class KerasPresetLoader(PresetLoader):
    """Loads models from presets saved with KerasNLP's `save_to_preset`."""

    def check_backbone_class(self):
        return check_config_class(self.config)

    def load_backbone(self, cls, load_weights, **kwargs):
        backbone = load_serialized_object(self.config, **kwargs)
        if load_weights:
            # Free any framework memory before loading new weights.
            jax_memory_cleanup(backbone)
            backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
        return backbone

    def load_tokenizer(self, cls, **kwargs):
        tokenizer_config = load_json(self.preset, TOKENIZER_CONFIG_FILE)
        tokenizer = load_serialized_object(tokenizer_config, **kwargs)
        tokenizer.load_preset_assets(self.preset)
        return tokenizer

    def load_task(self, cls, load_weights, **kwargs):
        # If there is no `task.json` or it's for the wrong class, delegate to
        # the super class loader.
        if not check_file_exists(self.preset, TASK_CONFIG_FILE):
            return super().load_task(cls, load_weights, **kwargs)
        task_config = load_json(self.preset, TASK_CONFIG_FILE)
        if not issubclass(check_config_class(task_config), cls):
            return super().load_task(cls, load_weights, **kwargs)
        # We found a `task.json` with a complete config for our class.
        task = load_serialized_object(task_config, **kwargs)
        if task.preprocessor is not None:
            task.preprocessor.tokenizer.load_preset_assets(self.preset)
        if load_weights:
            jax_memory_cleanup(task)
            # Task weights are optional; backbone weights are required.
            if check_file_exists(self.preset, TASK_WEIGHTS_FILE):
                task_weights = get_file(self.preset, TASK_WEIGHTS_FILE)
                task.load_task_weights(task_weights)
            backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
            task.backbone.load_weights(backbone_weights)
        return task

    def load_preprocessor(self, cls, **kwargs):
        # If there is no `preprocessor.json` or it's for the wrong class,
        # delegate to the super class loader.
        if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE):
            return super().load_preprocessor(cls, **kwargs)
        preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE)
        if not issubclass(check_config_class(preprocessor_json), cls):
            return super().load_preprocessor(cls, **kwargs)
        # We found a `preprocessor.json` with a complete config for our class.
        preprocessor = load_serialized_object(preprocessor_json, **kwargs)
        preprocessor.tokenizer.load_preset_assets(self.preset)
        return preprocessor
os.path.join(temp_dir, "test_incorrect_metadata") - os.mkdir(preset_dir) - json_path = os.path.join(preset_dir, METADATA_FILE) - data = {"key": "value"} - with open(json_path, "w") as f: - json.dump(data, f) - - with self.assertRaisesRegex(ValueError, "doesn't have `keras_version`"): - check_format(preset_dir) - @parameterized.named_parameters( ("gemma2_2b_en", "gemma2_2b_en", "bfloat16", False), ("llama2_7b_en_int8", "llama2_7b_en_int8", "bfloat16", True), diff --git a/keras_nlp/src/utils/transformers/convert.py b/keras_nlp/src/utils/transformers/convert.py deleted file mode 100644 index 522a534e93..0000000000 --- a/keras_nlp/src/utils/transformers/convert.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2024 The KerasNLP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Convert huggingface models to KerasNLP.""" - - -from keras_nlp.src.utils.transformers.convert_albert import load_albert_backbone -from keras_nlp.src.utils.transformers.convert_albert import ( - load_albert_tokenizer, -) -from keras_nlp.src.utils.transformers.convert_bart import load_bart_backbone -from keras_nlp.src.utils.transformers.convert_bart import load_bart_tokenizer -from keras_nlp.src.utils.transformers.convert_bert import load_bert_backbone -from keras_nlp.src.utils.transformers.convert_bert import load_bert_tokenizer -from keras_nlp.src.utils.transformers.convert_distilbert import ( - load_distilbert_backbone, -) -from keras_nlp.src.utils.transformers.convert_distilbert import ( - load_distilbert_tokenizer, -) -from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_backbone -from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_tokenizer -from keras_nlp.src.utils.transformers.convert_gpt2 import load_gpt2_backbone -from keras_nlp.src.utils.transformers.convert_gpt2 import load_gpt2_tokenizer -from keras_nlp.src.utils.transformers.convert_llama3 import load_llama3_backbone -from keras_nlp.src.utils.transformers.convert_llama3 import ( - load_llama3_tokenizer, -) -from keras_nlp.src.utils.transformers.convert_mistral import ( - load_mistral_backbone, -) -from keras_nlp.src.utils.transformers.convert_mistral import ( - load_mistral_tokenizer, -) -from keras_nlp.src.utils.transformers.convert_pali_gemma import ( - load_pali_gemma_backbone, -) -from keras_nlp.src.utils.transformers.convert_pali_gemma import ( - load_pali_gemma_tokenizer, -) - - -def load_transformers_backbone(cls, preset, load_weights): - """ - Load a Transformer model config and weights as a KerasNLP backbone. - - Args: - cls (class): Keras model class. - preset (str): Preset configuration name. - load_weights (bool): Whether to load the weights. - - Returns: - backbone: Initialized Keras model backbone. 
- """ - if cls is None: - raise ValueError("Backbone class is None") - if cls.__name__ == "BertBackbone": - return load_bert_backbone(cls, preset, load_weights) - if cls.__name__ == "GemmaBackbone": - return load_gemma_backbone(cls, preset, load_weights) - if cls.__name__ == "Llama3Backbone": - return load_llama3_backbone(cls, preset, load_weights) - if cls.__name__ == "PaliGemmaBackbone": - return load_pali_gemma_backbone(cls, preset, load_weights) - if cls.__name__ == "GPT2Backbone": - return load_gpt2_backbone(cls, preset, load_weights) - if cls.__name__ == "DistilBertBackbone": - return load_distilbert_backbone(cls, preset, load_weights) - if cls.__name__ == "AlbertBackbone": - return load_albert_backbone(cls, preset, load_weights) - if cls.__name__ == "BartBackbone": - return load_bart_backbone(cls, preset, load_weights) - if cls.__name__ == "MistralBackbone": - return load_mistral_backbone(cls, preset, load_weights) - raise ValueError( - f"{cls} has not been ported from the Hugging Face format yet. " - "Please check Hugging Face Hub for the Keras model. " - ) - - -def load_transformers_tokenizer(cls, preset): - """ - Load a Transformer tokenizer assets as a KerasNLP tokenizer. - - Args: - cls (class): Tokenizer class. - preset (str): Preset configuration name. - - Returns: - tokenizer: Initialized tokenizer. 
- """ - if cls is None: - raise ValueError("Tokenizer class is None") - if cls.__name__ == "BertTokenizer": - return load_bert_tokenizer(cls, preset) - if cls.__name__ == "GemmaTokenizer": - return load_gemma_tokenizer(cls, preset) - if cls.__name__ == "Llama3Tokenizer": - return load_llama3_tokenizer(cls, preset) - if cls.__name__ == "PaliGemmaTokenizer": - return load_pali_gemma_tokenizer(cls, preset) - if cls.__name__ == "GPT2Tokenizer": - return load_gpt2_tokenizer(cls, preset) - if cls.__name__ == "DistilBertTokenizer": - return load_distilbert_tokenizer(cls, preset) - if cls.__name__ == "AlbertTokenizer": - return load_albert_tokenizer(cls, preset) - if cls.__name__ == "BartTokenizer": - return load_bart_tokenizer(cls, preset) - if cls.__name__ == "MistralTokenizer": - return load_mistral_tokenizer(cls, preset) - raise ValueError( - f"{cls} has not been ported from the Hugging Face format yet. " - "Please check Hugging Face Hub for the Keras model. " - ) diff --git a/keras_nlp/src/utils/transformers/convert_albert.py b/keras_nlp/src/utils/transformers/convert_albert.py index f1338bf68e..171749fe2d 100644 --- a/keras_nlp/src/utils/transformers/convert_albert.py +++ b/keras_nlp/src/utils/transformers/convert_albert.py @@ -13,11 +13,10 @@ # limitations under the License. 
import numpy as np -from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone from keras_nlp.src.utils.preset_utils import get_file -from keras_nlp.src.utils.preset_utils import jax_memory_cleanup -from keras_nlp.src.utils.preset_utils import load_config -from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader + +backbone_cls = AlbertBackbone def convert_backbone_config(transformers_config): @@ -36,7 +35,7 @@ def convert_backbone_config(transformers_config): } -def convert_weights(backbone, loader): +def convert_weights(backbone, loader, transformers_config): # Embeddings loader.port_weight( keras_variable=backbone.token_embedding.embeddings, @@ -189,19 +188,6 @@ def convert_weights(backbone, loader): hf_weight_key="albert.pooler.bias", ) - return backbone - - -def load_albert_backbone(cls, preset, load_weights): - transformers_config = load_config(preset, HF_CONFIG_FILE) - keras_config = convert_backbone_config(transformers_config) - backbone = cls(**keras_config) - if load_weights: - jax_memory_cleanup(backbone) - with SafetensorLoader(preset) as loader: - convert_weights(backbone, loader) - return backbone - -def load_albert_tokenizer(cls, preset): - return cls(get_file(preset, "spiece.model")) +def convert_tokenizer(cls, preset, **kwargs): + return cls(get_file(preset, "spiece.model"), **kwargs) diff --git a/keras_nlp/src/utils/transformers/convert_albert_test.py b/keras_nlp/src/utils/transformers/convert_albert_test.py index 520ce41889..91576bc7af 100644 --- a/keras_nlp/src/utils/transformers/convert_albert_test.py +++ b/keras_nlp/src/utils/transformers/convert_albert_test.py @@ -13,7 +13,10 @@ # limitations under the License. 
import pytest +from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone from keras_nlp.src.models.albert.albert_classifier import AlbertClassifier +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.classifier import Classifier from keras_nlp.src.tests.test_case import TestCase @@ -26,4 +29,18 @@ def test_convert_tiny_preset(self): prompt = "That movies was terrible." model.predict([prompt]) + @pytest.mark.large + def test_class_detection(self): + model = Classifier.from_preset( + "hf://albert/albert-base-v2", + num_classes=2, + load_weights=False, + ) + self.assertIsInstance(model, AlbertClassifier) + model = Backbone.from_preset( + "hf://albert/albert-base-v2", + load_weights=False, + ) + self.assertIsInstance(model, AlbertBackbone) + # TODO: compare numerics with huggingface model diff --git a/keras_nlp/src/utils/transformers/convert_bart.py b/keras_nlp/src/utils/transformers/convert_bart.py index 69e425c638..c004c2898b 100644 --- a/keras_nlp/src/utils/transformers/convert_bart.py +++ b/keras_nlp/src/utils/transformers/convert_bart.py @@ -13,11 +13,10 @@ # limitations under the License.
import numpy as np -from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.models.bart.bart_backbone import BartBackbone from keras_nlp.src.utils.preset_utils import get_file -from keras_nlp.src.utils.preset_utils import jax_memory_cleanup -from keras_nlp.src.utils.preset_utils import load_config -from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader + +backbone_cls = BartBackbone def convert_backbone_config(transformers_config): @@ -32,7 +31,7 @@ def convert_backbone_config(transformers_config): } -def convert_weights(backbone, loader): +def convert_weights(backbone, loader, transformers_config): # Embeddings loader.port_weight( keras_variable=backbone.token_embedding.embeddings, @@ -363,24 +362,12 @@ def convert_weights(backbone, loader): hf_weight_key="decoder.layernorm_embedding.bias", ) - return backbone - - -def load_bart_backbone(cls, preset, load_weights): - transformers_config = load_config(preset, HF_CONFIG_FILE) - keras_config = convert_backbone_config(transformers_config) - backbone = cls(**keras_config) - if load_weights: - jax_memory_cleanup(backbone) - with SafetensorLoader(preset) as loader: - convert_weights(backbone, loader) - return backbone - -def load_bart_tokenizer(cls, preset): +def convert_tokenizer(cls, preset, **kwargs): vocab_file = get_file(preset, "vocab.json") merges_file = get_file(preset, "merges.txt") return cls( vocabulary=vocab_file, merges=merges_file, + **kwargs, ) diff --git a/keras_nlp/src/utils/transformers/convert_bart_test.py b/keras_nlp/src/utils/transformers/convert_bart_test.py index ca17d17431..0e7aa37f37 100644 --- a/keras_nlp/src/utils/transformers/convert_bart_test.py +++ b/keras_nlp/src/utils/transformers/convert_bart_test.py @@ -13,7 +13,10 @@ # limitations under the License. 
import pytest +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.bart.bart_backbone import BartBackbone from keras_nlp.src.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM +from keras_nlp.src.models.seq_2_seq_lm import Seq2SeqLM from keras_nlp.src.tests.test_case import TestCase @@ -24,4 +27,17 @@ def test_convert_tiny_preset(self): prompt = "What is your favorite condiment?" model.generate([prompt], max_length=15) + @pytest.mark.large + def test_class_detection(self): + model = Seq2SeqLM.from_preset( + "hf://cosmo3769/tiny-bart-test", + load_weights=False, + ) + self.assertIsInstance(model, BartSeq2SeqLM) + model = Backbone.from_preset( + "hf://cosmo3769/tiny-bart-test", + load_weights=False, + ) + self.assertIsInstance(model, BartBackbone) + # TODO: compare numerics with huggingface model diff --git a/keras_nlp/src/utils/transformers/convert_bert.py b/keras_nlp/src/utils/transformers/convert_bert.py index 7b64affbea..02caca3ba1 100644 --- a/keras_nlp/src/utils/transformers/convert_bert.py +++ b/keras_nlp/src/utils/transformers/convert_bert.py @@ -13,12 +13,12 @@ # limitations under the License. 
import numpy as np -from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.models.bert.bert_backbone import BertBackbone from keras_nlp.src.utils.preset_utils import HF_TOKENIZER_CONFIG_FILE from keras_nlp.src.utils.preset_utils import get_file -from keras_nlp.src.utils.preset_utils import jax_memory_cleanup -from keras_nlp.src.utils.preset_utils import load_config -from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader +from keras_nlp.src.utils.preset_utils import load_json + +backbone_cls = BertBackbone def convert_backbone_config(transformers_config): @@ -154,20 +154,10 @@ def transpose_and_reshape(x, shape): ) -def load_bert_backbone(cls, preset, load_weights): - transformers_config = load_config(preset, HF_CONFIG_FILE) - keras_config = convert_backbone_config(transformers_config) - backbone = cls(**keras_config) - if load_weights: - jax_memory_cleanup(backbone) - with SafetensorLoader(preset) as loader: - convert_weights(backbone, loader, transformers_config) - return backbone - - -def load_bert_tokenizer(cls, preset): - transformers_config = load_config(preset, HF_TOKENIZER_CONFIG_FILE) +def convert_tokenizer(cls, preset, **kwargs): + transformers_config = load_json(preset, HF_TOKENIZER_CONFIG_FILE) return cls( get_file(preset, "vocab.txt"), lowercase=transformers_config["do_lower_case"], + **kwargs, ) diff --git a/keras_nlp/src/utils/transformers/convert_bert_test.py b/keras_nlp/src/utils/transformers/convert_bert_test.py index 1f4ee37928..2d4d306ef1 100644 --- a/keras_nlp/src/utils/transformers/convert_bert_test.py +++ b/keras_nlp/src/utils/transformers/convert_bert_test.py @@ -13,7 +13,10 @@ # limitations under the License. 
import pytest +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.bert.bert_backbone import BertBackbone from keras_nlp.src.models.bert.bert_classifier import BertClassifier +from keras_nlp.src.models.classifier import Classifier from keras_nlp.src.tests.test_case import TestCase @@ -26,4 +29,18 @@ def test_convert_tiny_preset(self): prompt = "That movies was terrible." model.predict([prompt]) + @pytest.mark.large + def test_class_detection(self): + model = Classifier.from_preset( + "hf://google-bert/bert-base-uncased", + num_classes=2, + load_weights=False, + ) + self.assertIsInstance(model, BertClassifier) + model = Backbone.from_preset( + "hf://google-bert/bert-base-uncased", + load_weights=False, + ) + self.assertIsInstance(model, BertBackbone) + # TODO: compare numerics with huggingface model diff --git a/keras_nlp/src/utils/transformers/convert_distilbert.py b/keras_nlp/src/utils/transformers/convert_distilbert.py index 8ed1b2fc83..240763bd8c 100644 --- a/keras_nlp/src/utils/transformers/convert_distilbert.py +++ b/keras_nlp/src/utils/transformers/convert_distilbert.py @@ -13,12 +13,14 @@ # limitations under the License. 
import numpy as np -from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.models.distil_bert.distil_bert_backbone import ( + DistilBertBackbone, +) from keras_nlp.src.utils.preset_utils import HF_TOKENIZER_CONFIG_FILE from keras_nlp.src.utils.preset_utils import get_file -from keras_nlp.src.utils.preset_utils import jax_memory_cleanup -from keras_nlp.src.utils.preset_utils import load_config -from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader +from keras_nlp.src.utils.preset_utils import load_json + +backbone_cls = DistilBertBackbone def convert_backbone_config(transformers_config): @@ -33,7 +35,7 @@ def convert_backbone_config(transformers_config): } -def convert_weights(backbone, loader): +def convert_weights(backbone, loader, transformers_config): # Embeddings loader.port_weight( keras_variable=backbone.get_layer( @@ -162,23 +164,11 @@ def convert_weights(backbone, loader): hf_weight_key="distilbert.embeddings.LayerNorm.bias", ) - return backbone - - -def load_distilbert_backbone(cls, preset, load_weights): - transformers_config = load_config(preset, HF_CONFIG_FILE) - keras_config = convert_backbone_config(transformers_config) - backbone = cls(**keras_config) - if load_weights: - jax_memory_cleanup(backbone) - with SafetensorLoader(preset) as loader: - convert_weights(backbone, loader) - return backbone - -def load_distilbert_tokenizer(cls, preset): - transformers_config = load_config(preset, HF_TOKENIZER_CONFIG_FILE) +def convert_tokenizer(cls, preset, **kwargs): + transformers_config = load_json(preset, HF_TOKENIZER_CONFIG_FILE) return cls( get_file(preset, "vocab.txt"), lowercase=transformers_config["do_lower_case"], + **kwargs, ) diff --git a/keras_nlp/src/utils/transformers/convert_distilbert_test.py b/keras_nlp/src/utils/transformers/convert_distilbert_test.py index ea62d14ca0..297a45bfbf 100644 --- a/keras_nlp/src/utils/transformers/convert_distilbert_test.py +++ 
b/keras_nlp/src/utils/transformers/convert_distilbert_test.py @@ -13,6 +13,11 @@ # limitations under the License. import pytest +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.classifier import Classifier +from keras_nlp.src.models.distil_bert.distil_bert_backbone import ( + DistilBertBackbone, +) from keras_nlp.src.models.distil_bert.distil_bert_classifier import ( DistilBertClassifier, ) @@ -28,4 +33,18 @@ def test_convert_tiny_preset(self): prompt = "That movies was terrible." model.predict([prompt]) + @pytest.mark.large + def test_class_detection(self): + model = Classifier.from_preset( + "hf://distilbert/distilbert-base-uncased", + num_classes=2, + load_weights=False, + ) + self.assertIsInstance(model, DistilBertClassifier) + model = Backbone.from_preset( + "hf://distilbert/distilbert-base-uncased", + load_weights=False, + ) + self.assertIsInstance(model, DistilBertBackbone) + # TODO: compare numerics with huggingface model diff --git a/keras_nlp/src/utils/transformers/convert_gemma.py b/keras_nlp/src/utils/transformers/convert_gemma.py index 5000bde750..7eab62b17c 100644 --- a/keras_nlp/src/utils/transformers/convert_gemma.py +++ b/keras_nlp/src/utils/transformers/convert_gemma.py @@ -13,11 +13,10 @@ # limitations under the License. 
import numpy as np -from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone from keras_nlp.src.utils.preset_utils import get_file -from keras_nlp.src.utils.preset_utils import jax_memory_cleanup -from keras_nlp.src.utils.preset_utils import load_config -from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader + +backbone_cls = GemmaBackbone def convert_backbone_config(transformers_config): @@ -169,19 +168,6 @@ def convert_weights(backbone, loader, transformers_config): hf_weight_key="model.norm.weight", ) - return backbone - - -def load_gemma_backbone(cls, preset, load_weights): - transformers_config = load_config(preset, HF_CONFIG_FILE) - keras_config = convert_backbone_config(transformers_config) - backbone = cls(**keras_config) - if load_weights: - jax_memory_cleanup(backbone) - with SafetensorLoader(preset) as loader: - convert_weights(backbone, loader, transformers_config) - return backbone - -def load_gemma_tokenizer(cls, preset): - return cls(get_file(preset, "tokenizer.model")) +def convert_tokenizer(cls, preset, **kwargs): + return cls(get_file(preset, "tokenizer.model"), **kwargs) diff --git a/keras_nlp/src/utils/transformers/convert_gemma_test.py b/keras_nlp/src/utils/transformers/convert_gemma_test.py index 2347aa571d..fd0900b9f7 100644 --- a/keras_nlp/src/utils/transformers/convert_gemma_test.py +++ b/keras_nlp/src/utils/transformers/convert_gemma_test.py @@ -13,6 +13,9 @@ # limitations under the License. import pytest +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.causal_lm import CausalLM +from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_nlp.src.tests.test_case import TestCase @@ -28,4 +31,17 @@ def test_convert_tiny_preset(self): prompt = "What is your favorite condiment?" 
model.generate([prompt], max_length=15) + @pytest.mark.large + def test_class_detection(self): + model = CausalLM.from_preset( + "hf://ariG23498/tiny-gemma-test", + load_weights=False, + ) + self.assertIsInstance(model, GemmaCausalLM) + model = Backbone.from_preset( + "hf://ariG23498/tiny-gemma-test", + load_weights=False, + ) + self.assertIsInstance(model, GemmaBackbone) + # TODO: compare numerics with huggingface model diff --git a/keras_nlp/src/utils/transformers/convert_gpt2.py b/keras_nlp/src/utils/transformers/convert_gpt2.py index 2ac8a9a8a2..73bc596905 100644 --- a/keras_nlp/src/utils/transformers/convert_gpt2.py +++ b/keras_nlp/src/utils/transformers/convert_gpt2.py @@ -13,11 +13,10 @@ # limitations under the License. import numpy as np -from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.src.utils.preset_utils import get_file -from keras_nlp.src.utils.preset_utils import jax_memory_cleanup -from keras_nlp.src.utils.preset_utils import load_config -from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader + +backbone_cls = GPT2Backbone def convert_backbone_config(transformers_config): @@ -163,24 +162,12 @@ def convert_weights(backbone, loader, transformers_config): hf_weight_key="ln_f.bias", ) - return backbone - - -def load_gpt2_backbone(cls, preset, load_weights): - transformers_config = load_config(preset, HF_CONFIG_FILE) - keras_config = convert_backbone_config(transformers_config) - backbone = cls(**keras_config) - if load_weights: - jax_memory_cleanup(backbone) - with SafetensorLoader(preset) as loader: - convert_weights(backbone, loader, transformers_config) - return backbone - -def load_gpt2_tokenizer(cls, preset): +def convert_tokenizer(cls, preset, **kwargs): vocab_file = get_file(preset, "vocab.json") merges_file = get_file(preset, "merges.txt") return cls( vocabulary=vocab_file, merges=merges_file, + **kwargs, ) diff --git 
a/keras_nlp/src/utils/transformers/convert_gpt2_test.py b/keras_nlp/src/utils/transformers/convert_gpt2_test.py index c7b65eb87f..68fd0ef77a 100644 --- a/keras_nlp/src/utils/transformers/convert_gpt2_test.py +++ b/keras_nlp/src/utils/transformers/convert_gpt2_test.py @@ -13,6 +13,9 @@ # limitations under the License. import pytest +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.causal_lm import CausalLM +from keras_nlp.src.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_nlp.src.tests.test_case import TestCase @@ -24,4 +27,17 @@ def test_convert_tiny_preset(self): prompt = "What is your favorite condiment?" model.generate([prompt], max_length=15) + @pytest.mark.large + def test_class_detection(self): + model = CausalLM.from_preset( + "hf://openai-community/gpt2", + load_weights=False, + ) + self.assertIsInstance(model, GPT2CausalLM) + model = Backbone.from_preset( + "hf://openai-community/gpt2", + load_weights=False, + ) + self.assertIsInstance(model, GPT2Backbone) + # TODO: compare numerics with huggingface model diff --git a/keras_nlp/src/utils/transformers/convert_llama3.py b/keras_nlp/src/utils/transformers/convert_llama3.py index 0d248d4b6a..3402aff55d 100644 --- a/keras_nlp/src/utils/transformers/convert_llama3.py +++ b/keras_nlp/src/utils/transformers/convert_llama3.py @@ -13,10 +13,10 @@ # limitations under the License. 
import numpy as np -from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE -from keras_nlp.src.utils.preset_utils import jax_memory_cleanup -from keras_nlp.src.utils.preset_utils import load_config -from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader +from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone +from keras_nlp.src.utils.preset_utils import load_json + +backbone_cls = Llama3Backbone def convert_backbone_config(transformers_config): @@ -111,19 +111,8 @@ def transpose_and_reshape(x, shape): return backbone -def load_llama3_backbone(cls, preset, load_weights): - transformers_config = load_config(preset, HF_CONFIG_FILE) - keras_config = convert_backbone_config(transformers_config) - backbone = cls(**keras_config) - if load_weights: - jax_memory_cleanup(backbone) - with SafetensorLoader(preset) as loader: - convert_weights(backbone, loader, transformers_config) - return backbone - - -def load_llama3_tokenizer(cls, preset): - tokenizer_config = load_config(preset, "tokenizer.json") +def convert_tokenizer(cls, preset, **kwargs): + tokenizer_config = load_json(preset, "tokenizer.json") vocab = tokenizer_config["model"]["vocab"] merges = tokenizer_config["model"]["merges"] @@ -133,4 +122,4 @@ def load_llama3_tokenizer(cls, preset): vocab[bot["content"]] = bot["id"] vocab[eot["content"]] = eot["id"] - return cls(vocabulary=vocab, merges=merges) + return cls(vocabulary=vocab, merges=merges, **kwargs) diff --git a/keras_nlp/src/utils/transformers/convert_llama3_test.py b/keras_nlp/src/utils/transformers/convert_llama3_test.py index d27bc31e6f..8b78988c49 100644 --- a/keras_nlp/src/utils/transformers/convert_llama3_test.py +++ b/keras_nlp/src/utils/transformers/convert_llama3_test.py @@ -13,6 +13,9 @@ # limitations under the License. 
import pytest +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.causal_lm import CausalLM +from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone from keras_nlp.src.models.llama3.llama3_causal_lm import Llama3CausalLM from keras_nlp.src.tests.test_case import TestCase @@ -24,4 +27,17 @@ def test_convert_tiny_preset(self): prompt = "What is your favorite condiment?" model.generate([prompt], max_length=15) + @pytest.mark.large + def test_class_detection(self): + model = CausalLM.from_preset( + "hf://ariG23498/tiny-llama3-test", + load_weights=False, + ) + self.assertIsInstance(model, Llama3CausalLM) + model = Backbone.from_preset( + "hf://ariG23498/tiny-llama3-test", + load_weights=False, + ) + self.assertIsInstance(model, Llama3Backbone) + # TODO: compare numerics with huggingface model diff --git a/keras_nlp/src/utils/transformers/convert_mistral.py b/keras_nlp/src/utils/transformers/convert_mistral.py index 5a8b989a4a..df1ec79824 100644 --- a/keras_nlp/src/utils/transformers/convert_mistral.py +++ b/keras_nlp/src/utils/transformers/convert_mistral.py @@ -13,11 +13,10 @@ # limitations under the License. 
import numpy as np -from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.models.mistral.mistral_backbone import MistralBackbone from keras_nlp.src.utils.preset_utils import get_file -from keras_nlp.src.utils.preset_utils import jax_memory_cleanup -from keras_nlp.src.utils.preset_utils import load_config -from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader + +backbone_cls = MistralBackbone def convert_backbone_config(transformers_config): @@ -34,7 +33,7 @@ def convert_backbone_config(transformers_config): } -def convert_weights(backbone, loader): +def convert_weights(backbone, loader, transformers_config): # Embeddings loader.port_weight( keras_variable=backbone.token_embedding.embeddings, @@ -125,19 +124,6 @@ def convert_weights(backbone, loader): hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16), ) - return backbone - - -def load_mistral_backbone(cls, preset, load_weights): - transformers_config = load_config(preset, HF_CONFIG_FILE) - keras_config = convert_backbone_config(transformers_config) - backbone = cls(**keras_config) - if load_weights: - jax_memory_cleanup(backbone) - with SafetensorLoader(preset) as loader: - convert_weights(backbone, loader) - return backbone - -def load_mistral_tokenizer(cls, preset): - return cls(get_file(preset, "tokenizer.model")) +def convert_tokenizer(cls, preset, **kwargs): + return cls(get_file(preset, "tokenizer.model"), **kwargs) diff --git a/keras_nlp/src/utils/transformers/convert_mistral_test.py b/keras_nlp/src/utils/transformers/convert_mistral_test.py index 82ac9eccc4..56982faf3b 100644 --- a/keras_nlp/src/utils/transformers/convert_mistral_test.py +++ b/keras_nlp/src/utils/transformers/convert_mistral_test.py @@ -13,6 +13,9 @@ # limitations under the License. 
import pytest +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.causal_lm import CausalLM +from keras_nlp.src.models.mistral.mistral_backbone import MistralBackbone from keras_nlp.src.models.mistral.mistral_causal_lm import MistralCausalLM from keras_nlp.src.tests.test_case import TestCase @@ -24,4 +27,17 @@ def test_convert_tiny_preset(self): prompt = "What is your favorite condiment?" model.generate([prompt], max_length=15) + @pytest.mark.large + def test_class_detection(self): + model = CausalLM.from_preset( + "hf://cosmo3769/tiny-mistral-test", + load_weights=False, + ) + self.assertIsInstance(model, MistralCausalLM) + model = Backbone.from_preset( + "hf://cosmo3769/tiny-mistral-test", + load_weights=False, + ) + self.assertIsInstance(model, MistralBackbone) + # TODO: compare numerics with huggingface model diff --git a/keras_nlp/src/utils/transformers/convert_pali_gemma.py b/keras_nlp/src/utils/transformers/convert_pali_gemma.py index 3382794041..95f47e0102 100644 --- a/keras_nlp/src/utils/transformers/convert_pali_gemma.py +++ b/keras_nlp/src/utils/transformers/convert_pali_gemma.py @@ -13,11 +13,12 @@ # limitations under the License. 
import numpy as np -from keras_nlp.src.utils.preset_utils import HF_CONFIG_FILE +from keras_nlp.src.models.pali_gemma.pali_gemma_backbone import ( + PaliGemmaBackbone, +) from keras_nlp.src.utils.preset_utils import get_file -from keras_nlp.src.utils.preset_utils import jax_memory_cleanup -from keras_nlp.src.utils.preset_utils import load_config -from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader + +backbone_cls = PaliGemmaBackbone def convert_backbone_config(transformers_config): @@ -275,29 +276,6 @@ def convert_weights(backbone, loader, transformers_config): hook_fn=lambda hf_tensor, keras_shape: hf_tensor[: keras_shape[0]], ) - return backbone - - -def load_pali_gemma_backbone(cls, preset, load_weights): - transformers_config = load_config(preset, HF_CONFIG_FILE) - keras_config = convert_backbone_config(transformers_config) - backbone = cls(**keras_config) - if load_weights: - jax_memory_cleanup(backbone) - with SafetensorLoader(preset) as loader: - convert_weights(backbone, loader, transformers_config) - return backbone - - -def load_pali_gemma_tokenizer(cls, preset): - """ - Load the Gemma tokenizer. - - Args: - cls (class): Tokenizer class. - preset (str): Preset configuration name. - Returns: - tokenizer: Initialized tokenizer. 
- """ - return cls(get_file(preset, "tokenizer.model")) +def convert_tokenizer(cls, preset, **kwargs): + return cls(get_file(preset, "tokenizer.model"), **kwargs) diff --git a/keras_nlp/src/utils/transformers/convert_pali_gemma_test.py b/keras_nlp/src/utils/transformers/convert_pali_gemma_test.py index 994fe20d7f..dbd5405f53 100644 --- a/keras_nlp/src/utils/transformers/convert_pali_gemma_test.py +++ b/keras_nlp/src/utils/transformers/convert_pali_gemma_test.py @@ -14,6 +14,11 @@ import numpy as np import pytest +from keras_nlp.src.models.backbone import Backbone +from keras_nlp.src.models.causal_lm import CausalLM +from keras_nlp.src.models.pali_gemma.pali_gemma_backbone import ( + PaliGemmaBackbone, +) from keras_nlp.src.models.pali_gemma.pali_gemma_causal_lm import ( PaliGemmaCausalLM, ) @@ -30,4 +35,17 @@ def test_convert_tiny_preset(self): prompt = "describe the image " model.generate({"images": image, "prompts": prompt}, max_length=15) + @pytest.mark.large + def test_class_detection(self): + model = CausalLM.from_preset( + "hf://ariG23498/tiny-pali-gemma-test", + load_weights=False, + ) + self.assertIsInstance(model, PaliGemmaCausalLM) + model = Backbone.from_preset( + "hf://ariG23498/tiny-pali-gemma-test", + load_weights=False, + ) + self.assertIsInstance(model, PaliGemmaBackbone) + # TODO: compare numerics with huggingface model diff --git a/keras_nlp/src/utils/transformers/preset_loader.py b/keras_nlp/src/utils/transformers/preset_loader.py new file mode 100644 index 0000000000..1a1ce928ba --- /dev/null +++ b/keras_nlp/src/utils/transformers/preset_loader.py @@ -0,0 +1,73 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert huggingface models to KerasNLP.""" + + +from keras_nlp.src.utils.preset_utils import PresetLoader +from keras_nlp.src.utils.preset_utils import jax_memory_cleanup +from keras_nlp.src.utils.transformers import convert_albert +from keras_nlp.src.utils.transformers import convert_bart +from keras_nlp.src.utils.transformers import convert_bert +from keras_nlp.src.utils.transformers import convert_distilbert +from keras_nlp.src.utils.transformers import convert_gemma +from keras_nlp.src.utils.transformers import convert_gpt2 +from keras_nlp.src.utils.transformers import convert_llama3 +from keras_nlp.src.utils.transformers import convert_mistral +from keras_nlp.src.utils.transformers import convert_pali_gemma +from keras_nlp.src.utils.transformers.safetensor_utils import SafetensorLoader + + +class TransformersPresetLoader(PresetLoader): + def __init__(self, preset, config): + super().__init__(preset, config) + model_type = self.config["model_type"] + if model_type == "albert": + self.converter = convert_albert + elif model_type == "bart": + self.converter = convert_bart + elif model_type == "bert": + self.converter = convert_bert + elif model_type == "distilbert": + self.converter = convert_distilbert + elif model_type == "gemma" or model_type == "gemma2": + self.converter = convert_gemma + elif model_type == "gpt2": + self.converter = convert_gpt2 + elif model_type == "llama": + # TODO: handle other llama versions. 
+ self.converter = convert_llama3 + elif model_type == "mistral": + self.converter = convert_mistral + elif model_type == "paligemma": + self.converter = convert_pali_gemma + else: + raise ValueError( + "KerasNLP has no converter for huggingface/transformers models " + f"with model type `'{model_type}'`." + ) + + def check_backbone_class(self): + return self.converter.backbone_cls + + def load_backbone(self, cls, load_weights, **kwargs): + keras_config = self.converter.convert_backbone_config(self.config) + backbone = cls(**{**keras_config, **kwargs}) + if load_weights: + jax_memory_cleanup(backbone) + with SafetensorLoader(self.preset) as loader: + self.converter.convert_weights(backbone, loader, self.config) + return backbone + + def load_tokenizer(self, cls, **kwargs): + return self.converter.convert_tokenizer(cls, self.preset, **kwargs) diff --git a/keras_nlp/src/utils/transformers/safetensor_utils.py b/keras_nlp/src/utils/transformers/safetensor_utils.py index 60451a981f..1305bb77d7 100644 --- a/keras_nlp/src/utils/transformers/safetensor_utils.py +++ b/keras_nlp/src/utils/transformers/safetensor_utils.py @@ -17,7 +17,7 @@ from keras_nlp.src.utils.preset_utils import SAFETENSOR_FILE from keras_nlp.src.utils.preset_utils import check_file_exists from keras_nlp.src.utils.preset_utils import get_file -from keras_nlp.src.utils.preset_utils import load_config +from keras_nlp.src.utils.preset_utils import load_json try: import safetensors @@ -38,7 +38,7 @@ def __init__(self, preset): self.preset = preset if check_file_exists(preset, SAFETENSOR_CONFIG_FILE): - self.safetensor_config = load_config(preset, SAFETENSOR_CONFIG_FILE) + self.safetensor_config = load_json(preset, SAFETENSOR_CONFIG_FILE) else: self.safetensor_config = None self.safetensor_files = {}