Skip to content

Commit

Permalink
Refactored BertEmbedding Layer (#69)
Browse files Browse the repository at this point in the history
Co-authored-by: Mohammed <[email protected]>
Co-authored-by: JG <[email protected]>
  • Loading branch information
3 people authored Aug 9, 2023
1 parent 288814b commit 34e938e
Showing 1 changed file with 4 additions and 49 deletions.
53 changes: 4 additions & 49 deletions ivy_models/bert/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,6 @@
import math


class Embedding(ivy.Module):
    """Trainable lookup table mapping integer indices to dense vectors.

    Args:
        num_embedding: Size of the vocabulary (number of rows in the table).
        embedding_dim: Width of each embedding vector.
        padding_idx: Optional index whose looked-up vectors are forced to
            zero (useful for padding tokens).
        max_norm: Optional maximum norm passed through to ``ivy.embedding``.
        initializer: Weight initializer; defaults to Glorot-uniform when
            not supplied.
        v: Optional pre-built variable container forwarded to ``ivy.Module``.
    """

    def __init__(
        self,
        num_embedding: int,
        embedding_dim: int,
        padding_idx=None,
        max_norm=None,
        initializer=None,
        v=None,
    ):
        self.num_embedding = num_embedding
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        # Fall back to Glorot-uniform when no initializer is given.
        if initializer is None:
            initializer = ivy.stateful.initializers.GlorotUniform()
        self.W_initializer = initializer
        super(Embedding, self).__init__(v=v)

    def _create_variables(self, device, dtype=None):
        """Create the ``weight`` table of shape (num_embedding, embedding_dim)."""
        weight = self.W_initializer.create_variables(
            (self.num_embedding, self.embedding_dim),
            device,
            self.embedding_dim,
            self.num_embedding,
            dtype=dtype,
        )
        return {"weight": weight}

    def _mask_embed(self, indices, embed):
        """Zero every embedding row whose index equals ``padding_idx``."""
        pad_positions = ivy.expand_dims(indices == self.padding_idx, axis=-1)
        zero = ivy.array(0.0, dtype=embed.dtype)
        return ivy.where(pad_positions, zero, embed)

    def _forward(self, indices):
        """Look up ``indices`` in the weight table, masking padding if set."""
        emb = ivy.embedding(self.v.weight, indices, max_norm=self.max_norm)
        if self.padding_idx is None:
            return emb
        return self._mask_embed(indices, emb)


class BertEmbedding(ivy.Module):
def __init__(
self,
Expand All @@ -71,13 +26,13 @@ def __init__(
super(BertEmbedding, self).__init__(v=v)

def _build(self, *args, **kwargs):
self.word_embeddings = Embedding(
self.vocab_size, self.hidden_size, padding_idx=self.padding_idx
self.word_embeddings = ivy.Embedding(
self.vocab_size, self.hidden_size, self.padding_idx
)
self.position_embeddings = Embedding(
self.position_embeddings = ivy.Embedding(
self.max_position_embeddings, self.hidden_size
)
self.token_type_embeddings = Embedding(self.type_token_size, self.hidden_size)
self.token_type_embeddings = ivy.Embedding(self.type_token_size, self.hidden_size)
self.dropout = ivy.Dropout(self.drop_rate)
self.LayerNorm = ivy.LayerNorm([self.hidden_size], eps=self.layer_norm_eps)

Expand Down

0 comments on commit 34e938e

Please sign in to comment.