diff --git a/README.md b/README.md index 7a33d1f702e2a7..35f5312148a28f 100644 --- a/README.md +++ b/README.md @@ -240,6 +240,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[Funnel Transformer](https://huggingface.co/transformers/model_doc/funnel.html)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. 1. **[GPT](https://huggingface.co/transformers/model_doc/gpt.html)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. 1. **[GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. +1. **[GPT-J](https://huggingface.co/transformers/model_doc/gptj.html)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki. 1. **[GPT Neo](https://huggingface.co/transformers/model_doc/gpt_neo.html)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. 1. **[Hubert](https://huggingface.co/transformers/model_doc/hubert.html)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. 1. **[I-BERT](https://huggingface.co/transformers/model_doc/ibert.html)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer diff --git a/docs/source/index.rst b/docs/source/index.rst index 9d91b907d3626f..48e0fdbe634ac1 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -191,116 +191,118 @@ Supported models 30. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask Learners `__ by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. -31. :doc:`GPT Neo ` (from EleutherAI) released in the repository `EleutherAI/gpt-neo +31. :doc:`GPT-J ` (from EleutherAI) released in the repository `kingoflolz/mesh-transformer-jax + `__ by Ben Wang and Aran Komatsuzaki. +32. :doc:`GPT Neo ` (from EleutherAI) released in the repository `EleutherAI/gpt-neo `__ by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. -32. :doc:`Hubert ` (from Facebook) released with the paper `HuBERT: Self-Supervised Speech +33. :doc:`Hubert ` (from Facebook) released with the paper `HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units `__ by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. -33. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization +34. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization `__ by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer -34. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training +35. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training of Text and Layout for Document Image Understanding `__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. -35. :doc:`LayoutLMv2 ` (from Microsoft Research Asia) released with the paper `LayoutLMv2: +36. :doc:`LayoutLMv2 ` (from Microsoft Research Asia) released with the paper `LayoutLMv2: Multi-modal Pre-training for Visually-Rich Document Understanding `__ by Yang Xu, Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min Zhang, Lidong Zhou. -36. :doc:`LayoutXLM ` (from Microsoft Research Asia) released with the paper `LayoutXLM: +37. :doc:`LayoutXLM ` (from Microsoft Research Asia) released with the paper `LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding `__ by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei. -37. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer +38. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -38. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document +39. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -39. :doc:`LUKE ` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity +40. :doc:`LUKE ` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention `__ by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. -40. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality +41. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering `__ by Hao Tan and Mohit Bansal. -41. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual +42. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual Machine Translation `__ by by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. -42. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by +43. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by Jörg Tiedemann. The `Marian Framework `__ is being developed by the Microsoft Translator Team. -43. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for +44. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -44. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible +45. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning `__ by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. -45. :doc:`Megatron-BERT ` (from NVIDIA) released with the paper `Megatron-LM: Training +46. :doc:`Megatron-BERT ` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism `__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -46. :doc:`Megatron-GPT2 ` (from NVIDIA) released with the paper `Megatron-LM: Training +47. :doc:`Megatron-GPT2 ` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism `__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -47. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +48. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding `__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -48. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +49. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer `__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -49. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +50. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `__> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -50. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting +51. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -51. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +52. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -52. :doc:`RemBERT ` (from Google Research) released with the paper `Rethinking embedding coupling in +53. :doc:`RemBERT ` (from Google Research) released with the paper `Rethinking embedding coupling in pre-trained language models `__ by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. -53. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +54. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -54. :doc:`RoFormer ` (from ZhuiyiTechnology), released together with the paper a `RoFormer: +55. :doc:`RoFormer ` (from ZhuiyiTechnology), released together with the paper a `RoFormer: Enhanced Transformer with Rotary Position Embedding `__ by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. -55. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper +56. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper `fairseq S2T: Fast Speech-to-Text Modeling with fairseq `__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. -56. :doc:`Splinter ` (from Tel Aviv University), released together with the paper `Few-Shot +57. :doc:`Splinter ` (from Tel Aviv University), released together with the paper `Few-Shot Question Answering by Pretraining Span Selection `__ by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy. -57. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +58. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -58. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +59. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -59. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +60. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -60. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +61. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -61. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 +62. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. -62. :doc:`VisualBERT ` (from UCLA NLP) released with the paper `VisualBERT: A Simple and +63. :doc:`VisualBERT ` (from UCLA NLP) released with the paper `VisualBERT: A Simple and Performant Baseline for Vision and Language `__ by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. -63. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +64. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -64. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +65. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -65. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +66. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -66. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +67. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -67. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +68. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. -68. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised +69. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised Cross-Lingual Representation Learning For Speech Recognition `__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. @@ -372,6 +374,8 @@ Flax), PyTorch, and/or TensorFlow. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| GPT-J | ❌ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Hubert | ❌ | ❌ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ | @@ -574,6 +578,7 @@ Flax), PyTorch, and/or TensorFlow. model_doc/mt5 model_doc/gpt model_doc/gpt2 + model_doc/gptj model_doc/gpt_neo model_doc/hubert model_doc/pegasus diff --git a/docs/source/model_doc/gptj.rst b/docs/source/model_doc/gptj.rst new file mode 100644 index 00000000000000..1c296d453a1b76 --- /dev/null +++ b/docs/source/model_doc/gptj.rst @@ -0,0 +1,102 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +GPT-J +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The GPT-J model was released in the `kingoflolz/mesh-transformer-jax +`__ repository by Ben Wang and Aran Komatsuzaki. It is a GPT-2-like +causal language model trained on `the Pile `__ dataset. + +This model was contributed by `Stella Biderman `__. + +Tips: + +- Running [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6B) in float32 precision on GPU requires at least 24 GB of + RAM. On GPUs with less than 24 GB RAM, one should therefore load the model in half-precision: + +.. code-block:: + + >>> from transformers import GPTJForCausalLM + >>> import torch + + >>> model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16) + +Generation +_______________________________________________________________________________________________________________________ + +The :meth:`~transformers.generation_utils.GenerationMixin.generate` method can be used to generate text using GPT-J +model. + +.. code-block:: + + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + >>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B") + >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + + >>> prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \ + ... "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \ + ... "researchers was the fact that the unicorns spoke perfect English." + + >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + >>> gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100,) + >>> gen_text = tokenizer.batch_decode(gen_tokens)[0] + +...or in float16 precision: + +.. code-block:: + + >>> from transformers import GPTJForCausalLM, AutoTokenizer + >>> import torch + + >>> model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16) + >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + + >>> prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \ + ... "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \ + ... "researchers was the fact that the unicorns spoke perfect English." + + >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + >>> gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100,) + >>> gen_text = tokenizer.batch_decode(gen_tokens)[0] + + +GPTJConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.GPTJConfig + :members: + +GPTJModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.GPTJModel + :members: forward + + +GPTJForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.GPTJForCausalLM + :members: forward + + +GPTJForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.GPTJForSequenceClassification + :members: forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f636e2e8080b66..62c2b8fa741712 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -213,6 +213,7 @@ "models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"], "models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"], "models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"], + "models.gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"], "models.herbert": ["HerbertTokenizer"], "models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"], "models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"], @@ -824,6 +825,15 @@ "load_tf_weights_in_gpt_neo", ] ) + _import_structure["models.gptj"].extend( + [ + "GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST", + "GPTJForCausalLM", + "GPTJForSequenceClassification", + "GPTJModel", + "GPTJPreTrainedModel", + ] + ) _import_structure["models.hubert"].extend( [ "HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1966,6 +1976,7 @@ from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig + from .models.gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig from .models.herbert import HerbertTokenizer from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig @@ -2486,6 +2497,13 @@ GPTNeoPreTrainedModel, load_tf_weights_in_gpt_neo, ) + from .models.gptj import ( + GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST, + GPTJForCausalLM, + GPTJForSequenceClassification, + GPTJModel, + GPTJPreTrainedModel, + ) from .models.hubert import ( HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, HubertForCTC, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 771b6faf5573ac..75fcbfa6376db7 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -50,6 +50,7 @@ funnel, gpt2, gpt_neo, + gptj, herbert, hubert, ibert, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index f61854dd5666f3..efdd964324a6f9 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -26,6 +26,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( [ # Add configs here + ("gptj", "GPTJConfig"), ("layoutlmv2", "LayoutLMv2Config"), ("beit", "BeitConfig"), ("rembert", "RemBertConfig"), @@ -96,6 +97,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( [ # Add archive maps here + ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -158,6 +160,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("gptj", "GPT-J"), ("beit", "BeiT"), ("rembert", "RemBERT"), ("layoutlmv2", "LayoutLMv2"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 26680c0cb27010..3a3cb9dc6f6f7c 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -28,6 +28,7 @@ MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping + ("gptj", "GPTJModel"), ("layoutlmv2", "LayoutLMv2Model"), ("beit", "BeitModel"), ("rembert", "RemBertModel"), @@ -135,6 +136,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( [ # Model with LM heads mapping + ("gptj", "GPTJForCausalLM"), ("rembert", "RemBertForMaskedLM"), ("roformer", "RoFormerForMaskedLM"), ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"), @@ -183,6 +185,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Causal LM mapping + ("gptj", "GPTJForCausalLM"), ("rembert", "RemBertForCausalLM"), ("roformer", "RoFormerForCausalLM"), ("bigbird_pegasus", "BigBirdPegasusForCausalLM"), @@ -286,6 +289,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping + ("gptj", "GPTJForSequenceClassification"), ("layoutlmv2", "LayoutLMv2ForSequenceClassification"), ("rembert", "RemBertForSequenceClassification"), ("canine", "CanineForSequenceClassification"), diff --git a/src/transformers/models/gptj/__init__.py b/src/transformers/models/gptj/__init__.py new file mode 100644 index 00000000000000..043f926af53770 --- /dev/null +++ b/src/transformers/models/gptj/__init__.py @@ -0,0 +1,52 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_torch_available + + +_import_structure = { + "configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"], +} + +if is_torch_available(): + _import_structure["modeling_gptj"] = [ + "GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST", + "GPTJForCausalLM", + "GPTJForSequenceClassification", + "GPTJModel", + "GPTJPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig + + if is_torch_available(): + from .modeling_gptj import ( + GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST, + GPTJForCausalLM, + GPTJForSequenceClassification, + GPTJModel, + GPTJPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/gptj/configuration_gptj.py b/src/transformers/models/gptj/configuration_gptj.py new file mode 100644 index 00000000000000..227a73ab9c24bf --- /dev/null +++ b/src/transformers/models/gptj/configuration_gptj.py @@ -0,0 +1,151 @@ +# coding=utf-8 +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" GPT-J model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "EleutherAI/gpt-j-6B": "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/config.json", + # See all GPT-J models at https://huggingface.co/models?filter=gpt_j +} + + +class GPTJConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.GPTJModel`. It is used to + instantiate a GPT-J model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPT-J `gpt-j-6B + `__ architecture. Configuration objects inherit from + :class:`~transformers.PretrainedConfig` and can be used to control the model outputs. Read the documentation from + :class:`~transformers.PretrainedConfig` for more information. + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 50400): + Vocabulary size of the GPT-J model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.GPTJModel`. + n_positions (:obj:`int`, `optional`, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_ctx (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the causal mask (usually same as n_positions). + n_embd (:obj:`int`, `optional`, defaults to 4096): + Dimensionality of the embeddings and hidden states. + n_layer (:obj:`int`, `optional`, defaults to 28): + Number of hidden layers in the Transformer encoder. + n_head (:obj:`int`, `optional`, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + rotary_dim (:obj:`int`, `optional`, defaults to 64): + Number of dimensions in the embedding that Rotary Position Embedding is applied to. + n_inner (:obj:`int`, `optional`, defaults to None): + Dimensionality of the inner feed-forward layers. :obj:`None` will set it to 4 times n_embd + activation_function (:obj:`str`, `optional`, defaults to :obj:`"gelu_new"`): + Activation function, to be selected in the list :obj:`["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (:obj:`int`, `optional`, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`): + Scale attention weights by dividing by sqrt(hidden_size). + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). + + Example:: + + >>> from transformers import GPTJModel, GPTJConfig + + >>> # Initializing a GPT-J 6B configuration + >>> configuration = GPTJConfig() + + >>> # Initializing a model from the configuration + >>> model = GPTJModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "gptj" + + def __init__( + self, + vocab_size=50400, + n_positions=2048, + n_ctx=2048, + n_embd=4096, + n_layer=28, + n_head=16, + rotary_dim=64, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + gradient_checkpointing=False, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + **kwargs + ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.n_ctx = n_ctx + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.rotary_dim = rotary_dim + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.gradient_checkpointing = gradient_checkpointing + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + @property + def max_position_embeddings(self): + return self.n_positions + + @property + def hidden_size(self): + return self.n_embd + + @property + def num_attention_heads(self): + return self.n_head + + @property + def num_hidden_layers(self): + return self.n_layer diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py new file mode 100755 index 00000000000000..da44293a1ad9e7 --- /dev/null +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -0,0 +1,948 @@ +# coding=utf-8 +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch GPT-J model. """ + +from typing import Tuple + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from ...utils.model_parallel_utils import assert_device_map, get_device_map +from .configuration_gptj import GPTJConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-j-6b" +_CONFIG_FOR_DOC = "GPTJConfig" +_TOKENIZER_FOR_DOC = "GPT2Tokenizer" + +GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "EleutherAI/gpt-j-6B", + # See all GPT-J models at https://huggingface.co/models?filter=gptj +] + + +def fixed_pos_embedding(x, seq_dim=1, seq_len=None): + dim = x.shape[-1] + if seq_len is None: + seq_len = x.shape[seq_dim] + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) + sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float() + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = torch.stack((-x2, x1), axis=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') + + +def apply_rotary_pos_emb(x, sincos, offset=0): + sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +class GPTJAttention(nn.Module): + def __init__(self, config): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e9)) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.embed_dim = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_attention_heads + if self.head_dim * self.num_attention_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})." + ) + self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.rotary_dim = None + if config.rotary_dim is not None: + self.rotary_dim = config.rotary_dim + + def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): + """ + Splits n_ctx dim into attn_head_size and num_attention_heads + """ + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + tensor = tensor.view(*new_shape) + if rotary: + return tensor + if len(tensor.shape) == 5: + return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features) + elif len(tensor.shape) == 4: + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + + def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into n_ctx + """ + if len(tensor.shape) == 5: + tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() + elif len(tensor.shape) == 4: + tensor = tensor.permute(0, 2, 1, 3).contiguous() + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) + return tensor.view(new_shape) + + def _attn( + self, + query, + key, + value, + attention_mask=None, + head_mask=None, + ): + + # compute causal mask from causal mask buffer + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + + # Keep the attention weights computation in fp32 to avoid overflow issues + query = query.to(torch.float32) + key = key.to(torch.float32) + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + attn_weights = attn_weights / self.scale_attn + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.Softmax(dim=-1)(attn_weights) + attn_weights = attn_weights.to(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states, + attention_mask=None, + layer_past=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_attention_heads, self.head_dim, True) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) + value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) + + seq_len = key.shape[1] + offset = 0 + + if layer_past is not None: + offset = layer_past[0].shape[-2] + seq_len += offset + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) + q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + + key = torch.cat([k_rot, k_pass], dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) + + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class GPTJMLP(nn.Module): + def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim + super().__init__() + embed_dim = config.n_embd + + self.fc_in = nn.Linear(embed_dim, intermediate_size) + self.fc_out = nn.Linear(intermediate_size, embed_dim) + + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states): + hidden_states = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc_out(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class GPTJBlock(nn.Module): + def __init__(self, config): + super().__init__() + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = GPTJAttention(config) + self.mlp = GPTJMLP(inner_dim, config) + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_output + feed_forward_hidden_states + residual + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions) + + +class GPTJPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPTJConfig + base_model_prefix = "transformer" + is_parallelizable = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear,)): + # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +GPTJ_START_DOCSTRING = r""" + This model is a PyTorch `torch.nn.Module `_ sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config (:class:`~transformers.GPTJConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +GPTJ_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`transformers.GPTJTokenizer`. See + :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.n_positions - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_attention_heads,)` or :obj:`(n_layer, num_attention_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, n_ctx)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. Uses a device map to distribute + attention modules of the model across several devices. If no device map is given, it will evenly distribute blocks + across all devices. + + Args: + device_map (:obj:`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the GPT-J models have the + following number of attention modules: + + - gpt-j-6B: 28 + + Example:: + # Here is an example of a device map on a machine with 4 GPUs using gpt-j-6B, which has a total of 28 attention modules: + model = GPTJForCausalLM.from_pretrained('EleutherAI/gpt-j-6B') + device_map = {0: [0, 1, 2, 3, 4, 5, 6], + 1: [7, 8, 9, 10, 11, 12, 13], + 2: [14, 15, 16, 17, 18, 19, 20], + 3: [21, 22, 23, 24, 25, 26, 27]} + model.parallelize(device_map) +""" + +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to CPU from a model parallel state. + + Example:: + # On a 4 GPU machine with gpt-j-6B: + model = GPTJForCausalLM.from_pretrained('EleutherAI/gpt-j-6B') + device_map = {0: [0, 1, 2, 3, 4, 5, 6], + 1: [7, 8, 9, 10, 11, 12, 13], + 2: [14, 15, 16, 17, 18, 19, 20], + 3: [21, 22, 23, 24, 25, 26, 27]} + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() +""" + + +@add_start_docstrings( + "The bare GPT-J Model transformer outputting raw hidden-states without any specific head on top.", + GPTJ_START_DOCSTRING, +) +class GPTJModel(GPTJPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.n_embd + self.vocab_size = config.vocab_size + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPTJBlock(config) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads) + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Attention mask. + if attention_mask is not None: + assert batch_size > 0, "batch_size has to be defined and > 0" + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@add_start_docstrings( + """ + The GPT-J Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPTJ_START_DOCSTRING, +) +class GPTJForCausalLM(GPTJPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPTJModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size) + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return None + + def set_output_embeddings(self, new_embeddings): + return + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to + ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + # make sure sampling in fp16 works correctly and + # compute loss in fp32 to match with mesh-tf version + # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 + lm_logits = self.lm_head(hidden_states).to(torch.float32) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + + +@add_start_docstrings( + """ + The GPT-J Model transformer with a sequence classification head on top (linear layer). + + :class:`~transformers.GPTJForSequenceClassification` uses the last token in order to do the classification, as + other causal models (e.g. GPT, GPT-2, GPT-Neo) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each + row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot + guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take + the last value in each row of the batch). + """, + GPTJ_START_DOCSTRING, +) +class GPTJForSequenceClassification(GPTJPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPTJModel(config) + self.score = nn.Linear(config.n_ctx, self.num_labels, bias=False) + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[range(batch_size), sequence_lengths] + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 291b7c56445e2c..624042992f6df8 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -32,6 +32,22 @@ class TextGenerationPipeline(Pipeline): begging for his blessing. """ + ALLOWED_MODELS = [ + "XLNetLMHeadModel", + "TransfoXLLMHeadModel", + "ReformerModelWithLMHead", + "GPT2LMHeadModel", + "GPTJForCausalLM", + "GPTNeoForCausalLM", + "OpenAIGPTLMHeadModel", + "CTRLLMHeadModel", + "TFXLNetLMHeadModel", + "TFTransfoXLLMHeadModel", + "TFGPT2LMHeadModel", + "TFOpenAIGPTLMHeadModel", + "TFCTRLLMHeadModel", + ] + def __init__(self, *args, return_full_text=True, **kwargs): super().__init__(*args, **kwargs) self.check_model_type( diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 7241da720dc654..86de6778ca1fce 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1808,6 +1808,45 @@ def load_tf_weights_in_gpt_neo(*args, **kwargs): requires_backends(load_tf_weights_in_gpt_neo, ["torch"]) +GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class GPTJForCausalLM: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class GPTJForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class GPTJModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class GPTJPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/test_modeling_gptj.py b/tests/test_modeling_gptj.py new file mode 100644 index 00000000000000..5739aed5a1f76b --- /dev/null +++ b/tests/test_modeling_gptj.py @@ -0,0 +1,560 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import datetime +import unittest + +from transformers import GPTJConfig, is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask + + +if is_torch_available(): + import torch + + from transformers import ( + GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST, + AutoTokenizer, + GPTJForCausalLM, + GPTJForSequenceClassification, + GPTJModel, + ) + + +class GPTJModelTester: + def __init__( + self, + parent, + batch_size=14, + seq_length=7, + is_training=True, + use_token_type_ids=True, + use_input_mask=True, + use_labels=True, + use_mc_token_ids=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_token_type_ids = use_token_type_ids + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.use_mc_token_ids = use_mc_token_ids + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = None + self.bos_token_id = vocab_size - 1 + self.eos_token_id = vocab_size - 1 + self.pad_token_id = vocab_size - 1 + + def get_large_model_config(self): + return GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B") + + def prepare_config_and_inputs(self, gradient_checkpointing=False): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + mc_token_ids = None + if self.use_mc_token_ids: + mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config(gradient_checkpointing=gradient_checkpointing) + + head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) + + return ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) + + def get_config(self, gradient_checkpointing=False): + return GPTJConfig( + vocab_size=self.vocab_size, + n_embd=self.hidden_size, + n_layer=self.num_hidden_layers, + n_head=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + n_positions=self.max_position_embeddings, + n_ctx=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + use_cache=not gradient_checkpointing, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + gradient_checkpointing=gradient_checkpointing, + ) + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_gptj_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = GPTJModel(config=config) + model.to(torch_device) + model.eval() + + result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) + result = model(input_ids, token_type_ids=token_type_ids) + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(len(result.past_key_values), config.n_layer) + + def create_and_check_gptj_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = GPTJModel(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids) + outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + output, past = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + next_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size) + + # append to next input_ids and token_type_ids + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1) + + output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"] + output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_gptj_model_attention_mask_past( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args + ): + model = GPTJModel(config=config) + model.to(torch_device) + model.eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + half_seq_length = self.seq_length // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + output, past = model(input_ids, attention_mask=attn_mask).to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_gptj_model_past_large_inputs( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args + ): + model = GPTJModel(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask, use_cache=True) + + output, past = outputs.to_tuple() + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_token_types = ids_tensor([self.batch_size, 3], self.type_vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and token_type_ids + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask + )["last_hidden_state"] + output_from_past = model( + next_tokens, token_type_ids=next_token_types, attention_mask=next_attention_mask, past_key_values=past + )["last_hidden_state"] + self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1]) + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = GPTJForCausalLM(config) + model.to(torch_device) + model.eval() + + result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = GPTJForCausalLM(config) + model.to(torch_device) + + result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + result.loss.backward() + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + + ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "head_mask": head_mask} + + return config, inputs_dict + + +@require_torch +class GPTJModelTest(unittest.TestCase): + + all_model_classes = (GPTJModel, GPTJForCausalLM, GPTJForSequenceClassification) if is_torch_available() else () + all_generative_model_classes = (GPTJForCausalLM,) if is_torch_available() else () + fx_ready_model_classes = all_model_classes + test_pruning = False + test_missing_keys = False + test_model_parallel = False + + # special case for DoubleHeads model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + return inputs_dict + + def setUp(self): + self.model_tester = GPTJModelTester(self) + self.config_tester = ConfigTester(self, config_class=GPTJConfig, n_embd=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_gptj_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gptj_model(*config_and_inputs) + + def test_gptj_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gptj_model_past(*config_and_inputs) + + def test_gptj_model_att_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gptj_model_attention_mask_past(*config_and_inputs) + + def test_gptj_model_past_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gptj_model_past_large_inputs(*config_and_inputs) + + def test_gptj_lm_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_lm_head_model(*config_and_inputs) + + def test_gptj_gradient_checkpointing(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True) + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + + @slow + def test_batch_generation(self): + model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B") + model.to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + + tokenizer.padding_side = "left" + + # Define PAD Token = EOS Token = 50256 + tokenizer.pad_token = tokenizer.eos_token + model.config.pad_token_id = model.config.eos_token_id + + # use different length sentences to test batching + sentences = [ + "Hello, my dog is a little", + "Today, I", + ] + + inputs = tokenizer(sentences, return_tensors="pt", padding=True) + input_ids = inputs["input_ids"].to(torch_device) + token_type_ids = torch.cat( + [ + input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0), + input_ids.new_full((input_ids.shape[0], 1), 500), + ], + dim=-1, + ) + + outputs = model.generate( + input_ids=input_ids, + attention_mask=inputs["attention_mask"].to(torch_device), + ) + + outputs_tt = model.generate( + input_ids=input_ids, + attention_mask=inputs["attention_mask"].to(torch_device), + token_type_ids=token_type_ids, + ) + + inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) + output_non_padded = model.generate(input_ids=inputs_non_padded) + + num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item() + inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) + output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) + + batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) + batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True) + non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True) + padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True) + + expected_output_sentence = [ + "Hello, my dog is a little over a year old and has been diagnosed with a heart murmur", + "Today, I’m going to talk about the most important thing in the", + ] + self.assertListEqual(expected_output_sentence, batch_out_sentence) + self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output + self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence]) + + @slow + def test_model_from_pretrained(self): + for model_name in GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = GPTJModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +@require_torch +class GPTJModelLanguageGenerationTest(unittest.TestCase): + @slow + def test_lm_generate_gptj(self): + for checkpointing in [True, False]: + model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", gradient_checkpointing=checkpointing) + model.to(torch_device) + input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog + expected_output_ids = [ + 464, + 3290, + 1528, + 286, + 3931, + 389, + 2402, + 514, + 11, + 290, + 326, + 1724, + 340, + 447, + 247, + 82, + 640, + 284, + 923, + 3612, + ] # The dog days of summer are upon us, and that means it’s time to start thinking + output_ids = model.generate(input_ids, do_sample=False) + self.assertListEqual(output_ids[0].tolist(), expected_output_ids) + + @slow + def test_gptj_sample(self): + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B") + model.to(torch_device) + + torch.manual_seed(0) + tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) + input_ids = tokenized.input_ids.to(torch_device) + output_ids = model.generate(input_ids, do_sample=True) + output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + token_type_ids = tokenized.token_type_ids.to(torch_device) + output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5) + output_seq_tt = model.generate( + input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5 + ) + output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True) + output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True) + + EXPECTED_OUTPUT_STR = "Today is a nice day and I've already been enjoying it. I walked to work with my wife" + self.assertEqual(output_str, EXPECTED_OUTPUT_STR) + self.assertTrue( + all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))]) + ) # token_type_ids should change output + + @slow + def test_gptj_sample_max_time(self): + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") + model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B") + model.to(torch_device) + + torch.manual_seed(0) + tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True) + input_ids = tokenized.input_ids.to(torch_device) + + MAX_TIME = 0.5 + + start = datetime.datetime.now() + model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256) + duration = datetime.datetime.now() - start + self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) + self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) + + start = datetime.datetime.now() + model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256) + duration = datetime.datetime.now() - start + self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) + self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) + + start = datetime.datetime.now() + model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256) + duration = datetime.datetime.now() - start + self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) + self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) + + start = datetime.datetime.now() + model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256) + duration = datetime.datetime.now() - start + self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME)) + self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME)) + + start = datetime.datetime.now() + model.generate(input_ids, do_sample=False, max_time=None, max_length=256) + duration = datetime.datetime.now() - start + self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))