From 0c1c42c120519eca74082f819a20bfe3f02fe027 Mon Sep 17 00:00:00 2001 From: Philip May Date: Mon, 26 Jul 2021 14:30:05 +0200 Subject: [PATCH] add `classifier_dropout` to classification heads (#12794) * add classifier_dropout to Electra * no type annotations yet Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * add classifier_dropout to Electra * add classifier_dropout to Electra ForTokenClass. * add classifier_dropout to bert * add classifier_dropout to roberta * add classifier_dropout to big_bird * add classifier_dropout to mobilebert * empty commit to trigger CI * add classifier_dropout to reformer * add classifier_dropout to ConvBERT * add classifier_dropout to Albert * add classifier_dropout to Albert Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/albert/modeling_albert.py | 7 ++++++- .../models/albert/modeling_tf_albert.py | 7 ++++++- src/transformers/models/bert/configuration_bert.py | 4 ++++ src/transformers/models/bert/modeling_bert.py | 10 ++++++++-- src/transformers/models/bert/modeling_flax_bert.py | 14 ++++++++++++-- src/transformers/models/bert/modeling_tf_bert.py | 10 ++++++++-- .../models/big_bird/configuration_big_bird.py | 4 ++++ .../models/big_bird/modeling_big_bird.py | 10 ++++++++-- .../models/big_bird/modeling_flax_big_bird.py | 14 ++++++++++++-- .../models/convbert/configuration_convbert.py | 5 ++++- .../models/convbert/modeling_convbert.py | 10 ++++++++-- .../models/convbert/modeling_tf_convbert.py | 10 ++++++++-- .../models/electra/configuration_electra.py | 4 ++++ .../models/electra/modeling_electra.py | 10 ++++++++-- .../models/electra/modeling_flax_electra.py | 14 ++++++++++++-- .../models/electra/modeling_tf_electra.py | 12 ++++++++++-- .../models/mobilebert/configuration_mobilebert.py | 5 +++++ .../models/mobilebert/modeling_mobilebert.py | 10 ++++++++-- .../models/mobilebert/modeling_tf_mobilebert.py | 10 ++++++++-- .../models/reformer/configuration_reformer.py | 4 ++++ .../models/reformer/modeling_reformer.py | 5 ++++- .../models/roberta/modeling_flax_roberta.py | 14 ++++++++++++-- .../models/roberta/modeling_roberta.py | 10 ++++++++-- .../models/roberta/modeling_tf_roberta.py | 10 ++++++++-- 24 files changed, 179 insertions(+), 34 deletions(-) diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 24b81e2ef35fdf..dd07ebebd10931 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -1088,7 +1088,12 @@ def __init__(self, config): self.num_labels = config.num_labels self.albert = AlbertModel(config, add_pooling_layer=False) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout_prob = ( + config.classifier_dropout_prob + if config.classifier_dropout_prob is not None + else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout_prob) self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) self.init_weights() diff --git a/src/transformers/models/albert/modeling_tf_albert.py b/src/transformers/models/albert/modeling_tf_albert.py index c750705ee6886d..0bb3e3816e5f7c 100644 --- a/src/transformers/models/albert/modeling_tf_albert.py +++ b/src/transformers/models/albert/modeling_tf_albert.py @@ -1199,7 +1199,12 @@ def __init__(self, config: AlbertConfig, *inputs, **kwargs): self.num_labels = config.num_labels self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert") - 
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) + classifier_dropout_prob = ( + config.classifier_dropout_prob + if config.classifier_dropout_prob is not None + else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(rate=classifier_dropout_prob) self.classifier = tf.keras.layers.Dense( units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) diff --git a/src/transformers/models/bert/configuration_bert.py b/src/transformers/models/bert/configuration_bert.py index 92e989c80312b5..8359f0c3b7e2b2 100644 --- a/src/transformers/models/bert/configuration_bert.py +++ b/src/transformers/models/bert/configuration_bert.py @@ -104,6 +104,8 @@ class BertConfig(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. + classifier_dropout (:obj:`float`, `optional`): + The dropout ratio for the classification head. Examples:: @@ -138,6 +140,7 @@ def __init__( gradient_checkpointing=False, position_embedding_type="absolute", use_cache=True, + classifier_dropout=None, **kwargs ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -157,6 +160,7 @@ def __init__( self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache + self.classifier_dropout = classifier_dropout class BertOnnxConfig(OnnxConfig): diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 9606af37670253..e6244969139490 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1486,7 +1486,10 @@ def __init__(self, config): self.config = config self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.init_weights() @@ -1677,7 +1680,10 @@ def __init__(self, config): self.num_labels = config.num_labels self.bert = BertModel(config, add_pooling_layer=False) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.init_weights() diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 2d8d6139c3c5a3..2ec002cd3ceeee 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -915,7 +915,12 @@ class FlaxBertForSequenceClassificationModule(nn.Module): def setup(self): self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) self.classifier = nn.Dense( self.config.num_labels, dtype=self.dtype, @@ -1057,7 +1062,12 @@ class FlaxBertForTokenClassificationModule(nn.Module): def setup(self): 
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py index 988a6149a1cc6b..db396d70c3496a 100644 --- a/src/transformers/models/bert/modeling_tf_bert.py +++ b/src/transformers/models/bert/modeling_tf_bert.py @@ -1386,7 +1386,10 @@ def __init__(self, config: BertConfig, *inputs, **kwargs): self.num_labels = config.num_labels self.bert = TFBertMainLayer(config, name="bert") - self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(rate=classifier_dropout) self.classifier = tf.keras.layers.Dense( units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), @@ -1652,7 +1655,10 @@ def __init__(self, config: BertConfig, *inputs, **kwargs): self.num_labels = config.num_labels self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert") - self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(rate=classifier_dropout) self.classifier = tf.keras.layers.Dense( units=config.num_labels, kernel_initializer=get_initializer(config.initializer_range), diff --git a/src/transformers/models/big_bird/configuration_big_bird.py b/src/transformers/models/big_bird/configuration_big_bird.py index 18c80b1e282294..e6fdfd1d14cd97 100644 --- a/src/transformers/models/big_bird/configuration_big_bird.py +++ b/src/transformers/models/big_bird/configuration_big_bird.py @@ -84,6 +84,8 @@ class BigBirdConfig(PretrainedConfig): "block_sparse"`. gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): If True, use gradient checkpointing to save memory at the expense of slower backward pass. + classifier_dropout (:obj:`float`, `optional`): + The dropout ratio for the classification head. 
Example:: @@ -126,6 +128,7 @@ def __init__( block_size=64, num_random_blocks=3, gradient_checkpointing=False, + classifier_dropout=None, **kwargs ): super().__init__( @@ -157,3 +160,4 @@ def __init__( self.use_bias = use_bias self.block_size = block_size self.num_random_blocks = num_random_blocks + self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 429ac39f86e4fc..15aef2b789a4fa 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -2605,7 +2605,10 @@ class BigBirdClassificationHead(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) self.config = config @@ -2821,7 +2824,10 @@ def __init__(self, config): self.num_labels = config.num_labels self.bert = BigBirdModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.init_weights() diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index edbac4aab1b319..20526cc14ce9ec 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -1654,7 +1654,12 @@ class FlaxBigBirdClassificationHead(nn.Module): def setup(self): self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) - self.dropout = nn.Dropout(self.config.hidden_dropout_prob) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__(self, features, deterministic=True): @@ -1831,7 +1836,12 @@ class FlaxBigBirdForTokenClassificationModule(nn.Module): def setup(self): self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( diff --git a/src/transformers/models/convbert/configuration_convbert.py b/src/transformers/models/convbert/configuration_convbert.py index ef4df0ee5632ca..1f904ddfcee2d1 100644 --- a/src/transformers/models/convbert/configuration_convbert.py +++ b/src/transformers/models/convbert/configuration_convbert.py @@ -73,7 +73,8 @@ class ConvBertConfig(PretrainedConfig): The number of groups for grouped linear layers for ConvBert model conv_kernel_size (:obj:`int`, `optional`, defaults to 9): The size of the convolutional kernel. 
- + classifier_dropout (:obj:`float`, `optional`): + The dropout ratio for the classification head. Example:: >>> from transformers import ConvBertModel, ConvBertConfig @@ -108,6 +109,7 @@ def __init__( head_ratio=2, conv_kernel_size=9, num_groups=1, + classifier_dropout=None, **kwargs, ): super().__init__( @@ -134,3 +136,4 @@ def __init__( self.head_ratio = head_ratio self.conv_kernel_size = conv_kernel_size self.num_groups = num_groups + self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 09d3d0db8faaed..3625f65ff66d82 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -936,7 +936,10 @@ class ConvBertClassificationHead(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) self.config = config @@ -1152,7 +1155,10 @@ def __init__(self, config): self.num_labels = config.num_labels self.convbert = ConvBertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.init_weights() diff --git a/src/transformers/models/convbert/modeling_tf_convbert.py b/src/transformers/models/convbert/modeling_tf_convbert.py index f088db5ad16839..a1fd6cf4600a87 100644 --- a/src/transformers/models/convbert/modeling_tf_convbert.py +++ b/src/transformers/models/convbert/modeling_tf_convbert.py @@ -970,7 +970,10 @@ def __init__(self, config, **kwargs): self.dense = tf.keras.layers.Dense( config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" ) - self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(classifier_dropout) self.out_proj = tf.keras.layers.Dense( config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" ) @@ -1240,7 +1243,10 @@ def __init__(self, config, *inputs, **kwargs): self.num_labels = config.num_labels self.convbert = TFConvBertMainLayer(config, name="convbert") - self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(classifier_dropout) self.classifier = tf.keras.layers.Dense( config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) diff --git a/src/transformers/models/electra/configuration_electra.py b/src/transformers/models/electra/configuration_electra.py index b8bae422c049bd..f4dd18bf270c6b 100644 --- a/src/transformers/models/electra/configuration_electra.py +++ b/src/transformers/models/electra/configuration_electra.py @@ -104,6 +104,8 @@ class ElectraConfig(PretrainedConfig): `__. 
For more information on :obj:`"relative_key_query"`, please refer to `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.) `__. + classifier_dropout (:obj:`float`, `optional`): + The dropout ratio for the classification head. Examples:: @@ -141,6 +143,7 @@ def __init__( summary_last_dropout=0.1, pad_token_id=0, position_embedding_type="absolute", + classifier_dropout=None, **kwargs ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -164,3 +167,4 @@ def __init__( self.summary_activation = summary_activation self.summary_last_dropout = summary_last_dropout self.position_embedding_type = position_embedding_type + self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index aa41b45676354c..c4366f568483ec 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -900,7 +900,10 @@ class ElectraClassificationHead(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) def forward(self, features, **kwargs): @@ -1200,7 +1203,10 @@ def __init__(self, config): super().__init__(config) self.electra = ElectraModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.init_weights() diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index d5212851c802db..cbd7b00c6eb308 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -783,7 +783,12 @@ class FlaxElectraForTokenClassificationModule(nn.Module): def setup(self): self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) - self.dropout = nn.Dropout(self.config.hidden_dropout_prob) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Dense(self.config.num_labels) def __call__( @@ -1069,7 +1074,12 @@ class FlaxElectraClassificationHead(nn.Module): def setup(self): self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) - self.dropout = nn.Dropout(self.config.hidden_dropout_prob) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__(self, hidden_states, deterministic: bool = True): diff --git a/src/transformers/models/electra/modeling_tf_electra.py b/src/transformers/models/electra/modeling_tf_electra.py index 2383df177a95e4..878395af00bc71 100644 --- a/src/transformers/models/electra/modeling_tf_electra.py +++ 
b/src/transformers/models/electra/modeling_tf_electra.py @@ -1039,7 +1039,12 @@ def __init__(self, config, **kwargs): self.dense = tf.keras.layers.Dense( config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" ) - self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout + if config.classifier_dropout is not None + else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(classifier_dropout) self.out_proj = tf.keras.layers.Dense( config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" ) @@ -1309,7 +1314,10 @@ def __init__(self, config, **kwargs): super().__init__(config, **kwargs) self.electra = TFElectraMainLayer(config, name="electra") - self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(classifier_dropout) self.classifier = tf.keras.layers.Dense( config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) diff --git a/src/transformers/models/mobilebert/configuration_mobilebert.py b/src/transformers/models/mobilebert/configuration_mobilebert.py index aaafd7a37bef58..4f8e338d33aa9e 100644 --- a/src/transformers/models/mobilebert/configuration_mobilebert.py +++ b/src/transformers/models/mobilebert/configuration_mobilebert.py @@ -84,6 +84,8 @@ class MobileBertConfig(PretrainedConfig): Number of FFNs in a block. normalization_type (:obj:`str`, `optional`, defaults to :obj:`"no_norm"`): The normalization type in MobileBERT. + classifier_dropout (:obj:`float`, `optional`): + The dropout ratio for the classification head. 
Examples:: @@ -128,6 +130,7 @@ def __init__( num_feedforward_networks=4, normalization_type="no_norm", classifier_activation=True, + classifier_dropout=None, **kwargs ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -158,3 +161,5 @@ def __init__( self.true_hidden_size = intra_bottleneck_size else: self.true_hidden_size = hidden_size + + self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 448a894beb8d29..4abf50491821c6 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -1212,7 +1212,10 @@ def __init__(self, config): self.config = config self.mobilebert = MobileBertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.init_weights() @@ -1510,7 +1513,10 @@ def __init__(self, config): self.num_labels = config.num_labels self.mobilebert = MobileBertModel(config, add_pooling_layer=False) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.init_weights() diff --git a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py index 0a103b54f6109e..920086d6d6e1ee 100644 --- a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py @@ -1339,7 +1339,10 @@ def __init__(self, config, *inputs, **kwargs): self.num_labels = config.num_labels self.mobilebert = TFMobileBertMainLayer(config, name="mobilebert") - self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(classifier_dropout) self.classifier = tf.keras.layers.Dense( config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) @@ -1730,7 +1733,10 @@ def __init__(self, config, *inputs, **kwargs): self.num_labels = config.num_labels self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert") - self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(classifier_dropout) self.classifier = tf.keras.layers.Dense( config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" ) diff --git a/src/transformers/models/reformer/configuration_reformer.py b/src/transformers/models/reformer/configuration_reformer.py index 1f283b970887ee..b48fadfb32d909 100755 --- a/src/transformers/models/reformer/configuration_reformer.py +++ b/src/transformers/models/reformer/configuration_reformer.py @@ -140,6 +140,8 @@ class ReformerConfig(PretrainedConfig): Whether to tie input and output embeddings. 
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). + classifier_dropout (:obj:`float`, `optional`): + The dropout ratio for the classification head. Examples:: @@ -191,6 +193,7 @@ def __init__( vocab_size=320, tie_word_embeddings=False, use_cache=True, + classifier_dropout=None, **kwargs ): super().__init__( @@ -230,3 +233,4 @@ def __init__( self.chunk_size_lm_head = chunk_size_lm_head self.attn_layers = attn_layers self.use_cache = use_cache + self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 8521a9542b6d44..77f35b97faaf16 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -2474,7 +2474,10 @@ class ReformerClassificationHead(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) def forward(self, hidden_states, **kwargs): diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 3cfa430dd18d68..da6e300905ad3f 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -498,7 +498,12 @@ def setup(self): dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), ) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) self.out_proj = nn.Dense( self.config.num_labels, dtype=self.dtype, @@ -877,7 +882,12 @@ class FlaxRobertaForTokenClassificationModule(nn.Module): def setup(self): self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) - self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index b8228fa6f5cea4..fc377e8e40ca2c 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -1346,7 +1346,10 @@ def __init__(self, config): self.num_labels = config.num_labels self.roberta = RobertaModel(config, add_pooling_layer=False) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.init_weights() @@ -1427,7 +1430,10 @@ class RobertaClassificationHead(nn.Module): def 
__init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) def forward(self, features, **kwargs): diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py index 6439d010412cf9..41f112a11e394c 100644 --- a/src/transformers/models/roberta/modeling_tf_roberta.py +++ b/src/transformers/models/roberta/modeling_tf_roberta.py @@ -933,7 +933,10 @@ def __init__(self, config, **kwargs): activation="tanh", name="dense", ) - self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(classifier_dropout) self.out_proj = tf.keras.layers.Dense( config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" ) @@ -1206,7 +1209,10 @@ def __init__(self, config, *inputs, **kwargs): self.num_labels = config.num_labels self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") - self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(classifier_dropout) self.classifier = tf.keras.layers.Dense( config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" )
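
The heads patched above all share the same fallback logic: use ``classifier_dropout`` from the config when it is set, otherwise reuse ``hidden_dropout_prob`` (ALBERT exposes the option as ``classifier_dropout_prob``). The short sketch below is not part of the diff; it only illustrates the resulting behavior for BERT, assuming a transformers build that already contains this patch.

Example::

    from transformers import BertConfig, BertForSequenceClassification

    # classifier_dropout defaults to None, so the classification head
    # falls back to the encoder's hidden_dropout_prob (0.1 by default).
    config = BertConfig(num_labels=2)
    model = BertForSequenceClassification(config)
    assert model.dropout.p == config.hidden_dropout_prob

    # Setting classifier_dropout gives the head its own rate without
    # touching the dropout used inside the encoder layers.
    config = BertConfig(num_labels=2, classifier_dropout=0.25)
    model = BertForSequenceClassification(config)
    assert model.dropout.p == 0.25
    assert config.hidden_dropout_prob == 0.1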