Skip to content

Commit

Permalink
revert: put deduplicated LayerScale class back in gen_model module
Browse files Browse the repository at this point in the history
  • Loading branch information
lmmx committed Jan 2, 2024
1 parent aca9701 commit ab47953
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/uform/gen_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import BatchEncoding

from .models.encoders import LayerScale, VisualEncoder
from .models.encoders import VisualEncoder
from .models.image_utils import convert_to_rgb

IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)

__all__ = [
"LayerScale",
"ImageFeaturesPooler",
"VLMConfig",
"VLMPreTrainedModel",
Expand All @@ -35,6 +36,16 @@
]


class LayerScale(nn.Module):
def __init__(self, dim, init_values: float = 1e-5, inplace: bool = False):
super().__init__()
self.weight = nn.Parameter(init_values * torch.ones(dim))
self.inplace = inplace

def forward(self, x):
return x.mul_(self.weight) if self.inplace else x * self.weight


class ImageFeaturesPooler(nn.Module):
def __init__(
self,
Expand Down

0 comments on commit ab47953

Please sign in to comment.