-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
style: Fix multimodal_transformer.py with proper import structure
- Loading branch information
1 parent
e6889df
commit f78e152
Showing
2 changed files
with
122 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import re | ||
|
||
def fix_multimodal_transformer(): | ||
# Create proper class structure with fixed imports | ||
new_content = '''"""Multimodal transformer implementation.""" | ||
from pathlib import Path | ||
import logging | ||
import torch | ||
import torch.nn as nn | ||
from typing import Dict, Any, Optional, List, Union, Tuple | ||
from dataclasses import dataclass, field | ||
from src.models.layers.enhanced_transformer import EnhancedTransformer | ||
from src.models.multimodal.image_processor import ImageProcessor | ||
class MultiModalTransformer(nn.Module): | ||
"""Transformer model for multimodal inputs.""" | ||
def __init__( | ||
self, | ||
hidden_size: int = 768, | ||
num_attention_heads: int = 12, | ||
num_hidden_layers: int = 12, | ||
intermediate_size: int = 3072, | ||
hidden_dropout_prob: float = 0.1, | ||
attention_probs_dropout_prob: float = 0.1, | ||
): | ||
super().__init__() | ||
self.text_encoder = EnhancedTransformer( | ||
hidden_size=hidden_size, | ||
num_attention_heads=num_attention_heads, | ||
num_hidden_layers=num_hidden_layers, | ||
intermediate_size=intermediate_size, | ||
hidden_dropout_prob=hidden_dropout_prob, | ||
attention_probs_dropout_prob=attention_probs_dropout_prob, | ||
) | ||
self.image_processor = ImageProcessor() | ||
self.fusion_layer = nn.Linear(hidden_size * 2, hidden_size) | ||
def forward( | ||
self, | ||
text_input: torch.Tensor, | ||
image_input: torch.Tensor, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
) -> torch.Tensor: | ||
"""Forward pass through the multimodal transformer. | ||
Args: | ||
text_input: Input text tensor | ||
image_input: Input image tensor | ||
attention_mask: Optional attention mask | ||
Returns: | ||
Tensor containing fused multimodal representations | ||
""" | ||
text_features = self.text_encoder(text_input, attention_mask) | ||
image_features = self.image_processor(image_input) | ||
combined_features = torch.cat([text_features, image_features], dim=-1) | ||
fused_features = self.fusion_layer(combined_features) | ||
return fused_features | ||
''' | ||
|
||
# Write the new content | ||
with open('src/models/multimodal/multimodal_transformer.py', 'w') as f: | ||
f.write(new_content) | ||
|
||
if __name__ == '__main__': | ||
fix_multimodal_transformer() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,89 +1,57 @@ | ||
""".""" | ||
from typing import Dict | ||
from typing import Any | ||
from typing import Optional | ||
from typing import List | ||
from typing import Union | ||
from typing import Tuple | ||
import torch | ||
import numpy as np | ||
from torch.utils.data import DataLoader | ||
from torch.utils.data import Dataset | ||
import logging | ||
from tqdm import tqdm | ||
import os | ||
"""Multimodal transformer implementation.""" | ||
from pathlib import Path | ||
from dataclasses import dataclass | ||
from dataclasses import field | ||
from typing import Dict | ||
from typing import Any | ||
from typing import Optional | ||
from typing import List | ||
from typing import Union | ||
from typing import Tuple | ||
import torch | ||
import numpy as np | ||
from torch.utils.data import DataLoader | ||
from torch.utils.data import Dataset | ||
import logging | ||
from tqdm import tqdm | ||
import os | ||
from pathlib import Path | ||
from dataclasses import dataclass | ||
from dataclasses import field | ||
from torch.utils.data import DataLoader | ||
from tqdm import tqdm | ||
import os | ||
from pathlib import Path import logging | ||
from typing import Dict | ||
from typing import List | ||
from typing import Optional | ||
from typing import Tuple | ||
import torch | ||
import torch.nn as nn | ||
from dataclasses import dataclass | ||
@dataclass class | ||
Module for implementing specific functionality.Module containing specific functionality.Module containing specific functionality. | ||
hidden_states_list = [] | ||
if input_ids is not None: position_ids = torch.arange(: | ||
input_ids.size(1), | ||
device=input_ids.device | ||
).unsqueeze(0) | ||
word_embeds = self.text_embeddings["word_embeddings"](input_ids) | ||
position_embeds = self.text_embeddings["position_embeddings"]( | ||
position_ids | ||
) | ||
text_hidden_states = word_embeds + position_embeds | ||
text_hidden_states = self.layernorm(text_hidden_states) | ||
text_hidden_states = self.dropout(text_hidden_states) | ||
hidden_states_list.append(text_hidden_states) | ||
if pixel_values is not None: | ||
B, C, H, W = pixel_values.shape | ||
P = self.config.patch_size | ||
patches = pixel_values.unfold(2, P, P).unfold(3, P, P) | ||
patches = patches.contiguous().view( | ||
B, C, -1, P * P | ||
).transpose(1, 2) | ||
patches = patches.reshape(B, -1, C * P * P) | ||
patch_embeds = self.image_embeddings["patch_embeddings"](patches) | ||
position_ids = torch.arange( | ||
patches.size(1), | ||
device=patches.device | ||
).unsqueeze(0) | ||
position_embeds = self.image_embeddings["position_embeddings"]( | ||
position_ids | ||
from typing import Dict, Any, Optional, List, Union, Tuple | ||
from dataclasses import dataclass, field | ||
|
||
from src.models.layers.enhanced_transformer import EnhancedTransformer | ||
from src.models.multimodal.image_processor import ImageProcessor | ||
|
||
|
||
class MultiModalTransformer(nn.Module): | ||
"""Transformer model for multimodal inputs.""" | ||
|
||
def __init__( | ||
self, | ||
hidden_size: int = 768, | ||
num_attention_heads: int = 12, | ||
num_hidden_layers: int = 12, | ||
intermediate_size: int = 3072, | ||
hidden_dropout_prob: float = 0.1, | ||
attention_probs_dropout_prob: float = 0.1, | ||
): | ||
super().__init__() | ||
self.text_encoder = EnhancedTransformer( | ||
hidden_size=hidden_size, | ||
num_attention_heads=num_attention_heads, | ||
num_hidden_layers=num_hidden_layers, | ||
intermediate_size=intermediate_size, | ||
hidden_dropout_prob=hidden_dropout_prob, | ||
attention_probs_dropout_prob=attention_probs_dropout_prob, | ||
) | ||
image_hidden_states = patch_embeds + position_embeds | ||
image_hidden_states = self.layernorm(image_hidden_states) | ||
image_hidden_states = self.dropout(image_hidden_states) | ||
hidden_states_list.append(image_hidden_states) | ||
if hidden_states_list: hidden_states = torch.cat(hidden_states_list, dim=1): | ||
if attention_mask is not None and pixel_mask is not None: attention_mask = torch.cat(: | ||
[attention_mask, pixel_mask], | ||
dim=1 | ||
) | ||
for layer in self.encoder: hidden_states = layer( | ||
hidden_states, | ||
src_key_padding_mask=attention_mask | ||
) | ||
return {"hidden_states": hidden_states} | ||
return {"hidden_states": None} | ||
self.image_processor = ImageProcessor() | ||
self.fusion_layer = nn.Linear(hidden_size * 2, hidden_size) | ||
|
||
def forward( | ||
self, | ||
text_input: torch.Tensor, | ||
image_input: torch.Tensor, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
) -> torch.Tensor: | ||
"""Forward pass through the multimodal transformer. | ||
Args: | ||
text_input: Input text tensor | ||
image_input: Input image tensor | ||
attention_mask: Optional attention mask | ||
Returns: | ||
Tensor containing fused multimodal representations | ||
""" | ||
text_features = self.text_encoder(text_input, attention_mask) | ||
image_features = self.image_processor(image_input) | ||
combined_features = torch.cat([text_features, image_features], dim=-1) | ||
fused_features = self.fusion_layer(combined_features) | ||
return fused_features |