diff --git a/fix_multimodal_transformer.py b/fix_multimodal_transformer.py
new file mode 100644
index 000000000..87638e4a9
--- /dev/null
+++ b/fix_multimodal_transformer.py
@@ -0,0 +1,69 @@
+import re
+
+def fix_multimodal_transformer():
+    # Create proper class structure with fixed imports
+    new_content = '''"""Multimodal transformer implementation."""
+from pathlib import Path
+import logging
+import torch
+import torch.nn as nn
+from typing import Dict, Any, Optional, List, Union, Tuple
+from dataclasses import dataclass, field
+
+from src.models.layers.enhanced_transformer import EnhancedTransformer
+from src.models.multimodal.image_processor import ImageProcessor
+
+
+class MultiModalTransformer(nn.Module):
+    """Transformer model for multimodal inputs."""
+
+    def __init__(
+        self,
+        hidden_size: int = 768,
+        num_attention_heads: int = 12,
+        num_hidden_layers: int = 12,
+        intermediate_size: int = 3072,
+        hidden_dropout_prob: float = 0.1,
+        attention_probs_dropout_prob: float = 0.1,
+    ):
+        super().__init__()
+        self.text_encoder = EnhancedTransformer(
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            num_hidden_layers=num_hidden_layers,
+            intermediate_size=intermediate_size,
+            hidden_dropout_prob=hidden_dropout_prob,
+            attention_probs_dropout_prob=attention_probs_dropout_prob,
+        )
+        self.image_processor = ImageProcessor()
+        self.fusion_layer = nn.Linear(hidden_size * 2, hidden_size)
+
+    def forward(
+        self,
+        text_input: torch.Tensor,
+        image_input: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Forward pass through the multimodal transformer.
+
+        Args:
+            text_input: Input text tensor
+            image_input: Input image tensor
+            attention_mask: Optional attention mask
+
+        Returns:
+            Tensor containing fused multimodal representations
+        """
+        text_features = self.text_encoder(text_input, attention_mask)
+        image_features = self.image_processor(image_input)
+        combined_features = torch.cat([text_features, image_features], dim=-1)
+        fused_features = self.fusion_layer(combined_features)
+        return fused_features
+'''
+
+    # Write the new content
+    with open('src/models/multimodal/multimodal_transformer.py', 'w') as f:
+        f.write(new_content)
+
+if __name__ == '__main__':
+    fix_multimodal_transformer()
diff --git a/src/models/multimodal/multimodal_transformer.py b/src/models/multimodal/multimodal_transformer.py
index 8aaa298ca..3813b9d6d 100644
--- a/src/models/multimodal/multimodal_transformer.py
+++ b/src/models/multimodal/multimodal_transformer.py
@@ -1,89 +1,57 @@
-"""."""
-from typing import Dict
-from typing import Any
-from typing import Optional
-from typing import List
-from typing import Union
-from typing import Tuple
-import torch
-import numpy as np
-from torch.utils.data import DataLoader
-from torch.utils.data import Dataset
-import logging
-from tqdm import tqdm
-import os
+"""Multimodal transformer implementation."""
 from pathlib import Path
-from dataclasses import dataclass
-from dataclasses import field
-from typing import Dict
-from typing import Any
-from typing import Optional
-from typing import List
-from typing import Union
-from typing import Tuple
-import torch
-import numpy as np
-from torch.utils.data import DataLoader
-from torch.utils.data import Dataset
 import logging
-from tqdm import tqdm
-import os
-from pathlib import Path
-from dataclasses import dataclass
-from dataclasses import field
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-import os
-from pathlib import Path
-import logging
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
+import torch
 import torch.nn as nn
-from dataclasses import dataclass
-@dataclass class
-Module for implementing specific functionality.Module containing specific functionality.Module containing specific functionality.
-hidden_states_list = []
-if input_ids is not None: position_ids = torch.arange(:
-    input_ids.size(1),
-    device=input_ids.device
-).unsqueeze(0)
-    word_embeds = self.text_embeddings["word_embeddings"](input_ids)
-    position_embeds = self.text_embeddings["position_embeddings"](
-        position_ids
-    )
-    text_hidden_states = word_embeds + position_embeds
-    text_hidden_states = self.layernorm(text_hidden_states)
-    text_hidden_states = self.dropout(text_hidden_states)
-    hidden_states_list.append(text_hidden_states)
-    if pixel_values is not None:
-        B, C, H, W = pixel_values.shape
-        P = self.config.patch_size
-        patches = pixel_values.unfold(2, P, P).unfold(3, P, P)
-        patches = patches.contiguous().view(
-            B, C, -1, P * P
-        ).transpose(1, 2)
-        patches = patches.reshape(B, -1, C * P * P)
-        patch_embeds = self.image_embeddings["patch_embeddings"](patches)
-        position_ids = torch.arange(
-            patches.size(1),
-            device=patches.device
-        ).unsqueeze(0)
-        position_embeds = self.image_embeddings["position_embeddings"](
-            position_ids
+from typing import Dict, Any, Optional, List, Union, Tuple
+from dataclasses import dataclass, field
+
+from src.models.layers.enhanced_transformer import EnhancedTransformer
+from src.models.multimodal.image_processor import ImageProcessor
+
+
+class MultiModalTransformer(nn.Module):
+    """Transformer model for multimodal inputs."""
+
+    def __init__(
+        self,
+        hidden_size: int = 768,
+        num_attention_heads: int = 12,
+        num_hidden_layers: int = 12,
+        intermediate_size: int = 3072,
+        hidden_dropout_prob: float = 0.1,
+        attention_probs_dropout_prob: float = 0.1,
+    ):
+        super().__init__()
+        self.text_encoder = EnhancedTransformer(
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            num_hidden_layers=num_hidden_layers,
+            intermediate_size=intermediate_size,
+            hidden_dropout_prob=hidden_dropout_prob,
+            attention_probs_dropout_prob=attention_probs_dropout_prob,
         )
-        image_hidden_states = patch_embeds + position_embeds
-        image_hidden_states = self.layernorm(image_hidden_states)
-        image_hidden_states = self.dropout(image_hidden_states)
-        hidden_states_list.append(image_hidden_states)
-    if hidden_states_list: hidden_states = torch.cat(hidden_states_list, dim=1):
-        if attention_mask is not None and pixel_mask is not None: attention_mask = torch.cat(:
-            [attention_mask, pixel_mask],
-            dim=1
-        )
-        for layer in self.encoder: hidden_states = layer(
-            hidden_states,
-            src_key_padding_mask=attention_mask
-        )
-        return {"hidden_states": hidden_states}
-        return {"hidden_states": None}
\ No newline at end of file
+        self.image_processor = ImageProcessor()
+        self.fusion_layer = nn.Linear(hidden_size * 2, hidden_size)
+
+    def forward(
+        self,
+        text_input: torch.Tensor,
+        image_input: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Forward pass through the multimodal transformer.
+
+        Args:
+            text_input: Input text tensor
+            image_input: Input image tensor
+            attention_mask: Optional attention mask
+
+        Returns:
+            Tensor containing fused multimodal representations
+        """
+        text_features = self.text_encoder(text_input, attention_mask)
+        image_features = self.image_processor(image_input)
+        combined_features = torch.cat([text_features, image_features], dim=-1)
+        fused_features = self.fusion_layer(combined_features)
+        return fused_features
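For reviewers, a minimal usage sketch of the rewritten module follows. It is not part of the patch, and the shapes are assumptions: `EnhancedTransformer` and `ImageProcessor` are not shown in this diff, so the sketch assumes the text encoder maps a (batch, seq_len, hidden_size) input to a tensor of the same shape, and that the image processor emits features with a matching (batch, seq_len, hidden_size) layout, since `forward` concatenates the two along the last dimension before the fusion projection.

```python
# Usage sketch only -- not part of the patch. The input shapes below are
# assumptions about EnhancedTransformer and ImageProcessor, which this diff
# does not define.
import torch

from src.models.multimodal.multimodal_transformer import MultiModalTransformer

model = MultiModalTransformer(hidden_size=768)
model.eval()

batch, seq_len, hidden = 2, 16, 768
text_input = torch.randn(batch, seq_len, hidden)   # assumed text encoder input
image_input = torch.randn(batch, 3, 224, 224)      # assumed image batch layout
attention_mask = torch.ones(batch, seq_len)

with torch.no_grad():
    fused = model(text_input, image_input, attention_mask)

# torch.cat(..., dim=-1) in forward() requires the text and image features to
# agree on every dimension except the last, so fused is (batch, seq_len, 768).
print(fused.shape)
```

One design note on the fusion step: because the concatenation is along `dim=-1`, the patch implicitly requires both encoders to produce the same sequence length. If `ImageProcessor` emits a different number of patch tokens than the text sequence has positions, the features would need pooling or interpolation to a common length before `fusion_layer` can be applied.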