Skip to content

Commit

Permalink
[Bugfix] remove post_layernorm in siglip (#8106)
Browse files Browse the repository at this point in the history
  • Loading branch information
wnma3mz authored Sep 4, 2024
1 parent ccd7207 commit d331156
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions vllm/model_executor/models/siglip.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,14 +443,27 @@ def __init__(
self.config = config
embed_dim = config.hidden_size

if (num_hidden_layers_override is None
or num_hidden_layers_override == config.num_hidden_layers):
self.need_post_layernorm = True
elif num_hidden_layers_override > config.num_hidden_layers:
raise ValueError(
"num_hidden_layers_override cannot be greater than "
"num_hidden_layers")
else:
self.need_post_layernorm = False

self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
if self.need_post_layernorm:
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
self.post_layernorm = nn.Identity()
self.use_head = (True if not hasattr(config, "vision_use_head") else
config.vision_use_head)
if self.use_head:
Expand All @@ -470,7 +483,6 @@ def forward(
encoder_outputs = self.encoder(inputs_embeds=hidden_states)

last_hidden_state = self.post_layernorm(encoder_outputs)

# TODO: add this back when pooled_output is used in inference
# if self.use_head:
# pooled_output = self.head(last_hidden_state)
Expand Down Expand Up @@ -499,6 +511,10 @@ def __init__(
num_hidden_layers_override=num_hidden_layers_override,
)

@property
def need_post_layernorm(self):
return self.vision_model.need_post_layernorm

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding

Expand All @@ -517,6 +533,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
layer_count = len(self.vision_model.encoder.layers)

for name, loaded_weight in weights:
# post_layernorm is optional in SiglipVisionModel
if ("vision_model.post_layernorm" in name
and not self.need_post_layernorm):
continue

# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
layer_idx = int(name.split(".")[3])
Expand Down

0 comments on commit d331156

Please sign in to comment.