[SOME PROGRESS]
Kye committed Feb 15, 2024
1 parent de7ec0f commit 83221d8
Showing 1 changed file with 101 additions and 52 deletions.
limoe/main.py: 101 additions & 52 deletions (153 changes)
@@ -1,9 +1,40 @@
-import torch
-from torch import nn, Tensor, einsum
-from zeta.nn import MoERouter, MixtureOfExperts, FeedForward, PreNorm, Attention
+import torch
+from torch import nn, Tensor
+from zeta.nn import (
+    MixtureOfExperts,
+    Attention,
+)
 
 
 class DenseEncoderLayer(nn.Module):
+    """
+    DenseEncoderLayer is a single layer of a dense encoder: attention followed by a mixture of experts.
+
+    Args:
+        dim (int): The input dimension.
+        depth (int): The depth of the encoder layer.
+        heads (int): The number of attention heads.
+        num_experts (int): The number of experts in the mixture of experts.
+        dim_head (int): The dimension of each attention head.
+        dropout (float): The dropout rate.
+        ff_mult (int): The multiplier for the feed-forward network dimension.
+        *args: Additional positional arguments.
+        **kwargs: Additional keyword arguments.
+
+    Attributes:
+        dim (int): The input dimension.
+        depth (int): The depth of the encoder layer.
+        num_experts (int): The number of experts in the mixture of experts.
+        dim_head (int): The dimension of each attention head.
+        dropout (float): The dropout rate.
+        ff_mult (int): The multiplier for the feed-forward network dimension.
+        heads (int): The number of attention heads, computed as dim // dim_head.
+        scale (float): The scaling factor for the attention weights.
+        experts (MixtureOfExperts): The mixture of experts module.
+        attn (Attention): The attention module.
+    """
+
     def __init__(
         self,
         dim: int,
@@ -14,7 +45,7 @@ def __init__(
         dropout: float,
         ff_mult: int,
         *args,
-        **kwargs
+        **kwargs,
     ):
         super().__init__()
         self.dim = dim
@@ -23,21 +54,21 @@ def __init__(
         self.dim_head = dim_head
         self.dropout = dropout
         self.ff_mult = ff_mult
 
         self.heads = self.dim // self.dim_head
-        self.scale = self.dim_head ** -0.5
+        self.scale = self.dim_head**-0.5
 
         gpu = "cuda" if torch.cuda.is_available() else "cpu"
 
         # Experts
         self.experts = MixtureOfExperts(
             dim=self.dim,
             num_experts=self.num_experts,
-            dim_head=self.dim_head,
+            # dim_head=self.dim_head,
             dropout=self.dropout,
-            ff_mult=ff_mult
+            ff_mult=ff_mult,
         )
 
         # Attention
         self.attn = Attention(
             dim,
@@ -47,56 +78,74 @@ def __init__(
             flash=gpu,
             qk_norm=True,
             *args,
-            **kwargs
+            **kwargs,
         )
 
     def forward(self, x: Tensor):
+        """
+        Forward pass of the DenseEncoderLayer.
+
+        Args:
+            x (Tensor): The input tensor.
+
+        Returns:
+            Tensor: The output tensor.
+        """
         # Attention
-        x = self.attn(x)
+        x, _ = self.attn(x)
 
         # Expert
         x = self.experts(x)
 
         return x
 
 
-
 # Tensor
 x = torch.randn(1, 64, 512)
-model = DenseEncoderLayer(512, 4, 8, 4, 64, 0.1, 4)
+model = DenseEncoderLayer(
+    dim=512,
+    depth=3,
+    heads=8,
+    dim_head=64,
+    num_experts=4,
+    dropout=0.1,
+    ff_mult=4,
+)
 print(model(x).shape)
 
 
-# LiMoE: Linear Mixture of Experts
-class LiMoE(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        depth: int,
-        num_experts: int,
-        dim_head: int,
-        dropout: float,
-        *args,
-        **kwargs
-    ):
-        super().__init__()
-        self.dim = dim
-        self.depth = depth
-        self.num_experts = num_experts
-        self.dim_head = dim_head
-        self.dropout = dropout
-        self.heads = self.dim // self.dim_head
-        self.scale = self.dim_head ** -0.5
-
-    def forward(self, x: Tensor):
-        # Encoder
-        for _ in range(self.depth):
-            x = DenseEncoderLayer(
-                dim=self.dim,
-                depth=self.depth,
-                num_experts=self.num_experts,
-                dim_head=self.dim_head,
-                dropout=self.dropout
-            )(x)
-
-        return x
+# class LiMoE(nn.Module):
+#     def __init__(
+#         self,
+#         dim: int,
+#         depth: int,
+#         num_experts: int,
+#         dim_head: int,
+#         dropout: float,
+#         *args,
+#         **kwargs,
+#     ):
+#         super().__init__()
+#         self.dim = dim
+#         self.depth = depth
+#         self.num_experts = num_experts
+#         self.dim_head = dim_head
+#         self.dropout = dropout
+#         self.heads = self.dim // self.dim_head
+#         self.scale = self.dim_head**-0.5
+
+#     def forward(self, x: Tensor):
+#         # Encoder
+#         for _ in range(self.depth):
+#             x = DenseEncoderLayer(
+#                 dim=self.dim,
+#                 depth=self.depth,
+#                 num_experts=self.num_experts,
+#                 dim_head=self.dim_head,
+#                 dropout=self.dropout,
+#             )(x)
+
+#         return x
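
The LiMoE wrapper is left commented out in this commit. Below is a minimal sketch of one way it could be reinstated; it is not part of the commit, and it assumes the DenseEncoderLayer defined above. Two things motivate the changes from the commented-out version: that version's forward would build a fresh DenseEncoderLayer on every pass (so its weights would be re-initialized each call and never registered as submodules), and it does not pass heads or ff_mult, which the DenseEncoderLayer test above supplies. The sketch builds the stack once in __init__ with nn.ModuleList and forwards the full argument set.

# Sketch only, not the author's implementation: LiMoE as a stack of the
# DenseEncoderLayer defined above in limoe/main.py.
import torch
from torch import nn, Tensor


class LiMoE(nn.Module):
    def __init__(
        self,
        dim: int,
        depth: int,
        heads: int,
        num_experts: int,
        dim_head: int,
        dropout: float,
        ff_mult: int,
    ):
        super().__init__()
        # One DenseEncoderLayer per level of depth, created once so the
        # weights are registered with the module and reused across calls.
        self.layers = nn.ModuleList(
            [
                DenseEncoderLayer(
                    dim=dim,
                    depth=depth,
                    heads=heads,
                    num_experts=num_experts,
                    dim_head=dim_head,
                    dropout=dropout,
                    ff_mult=ff_mult,
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        # Apply the encoder layers sequentially.
        for layer in self.layers:
            x = layer(x)
        return x


# Usage, mirroring the DenseEncoderLayer test above:
x = torch.randn(1, 64, 512)
model = LiMoE(
    dim=512,
    depth=3,
    heads=8,
    num_experts=4,
    dim_head=64,
    dropout=0.1,
    ff_mult=4,
)
print(model(x).shape)  # expected to match the input shape, torch.Size([1, 64, 512])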
