-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: resolve syntax and indentation issues in text_to_anything.py
- Loading branch information
1 parent
63a6069
commit 2b0b730
Showing
2 changed files
with
112 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import re | ||
|
||
def fix_text_to_anything(): | ||
with open('src/models/text_to_anything.py', 'r') as f: | ||
content = f.readlines() | ||
|
||
# Add missing imports if not present | ||
imports = [ | ||
"import jax.numpy as jnp\n", | ||
"from typing import Dict, List, Optional, Tuple, Union, Any\n", | ||
"from flax import linen as nn\n" | ||
] | ||
|
||
# Find where to insert imports | ||
for i, line in enumerate(content): | ||
if line.startswith("from flax import struct"): | ||
content = content[:i] + imports + content[i:] | ||
break | ||
|
||
# Fix the content | ||
fixed_content = [] | ||
in_call_method = False | ||
batch_size_initialized = False | ||
skip_next_lines = 0 | ||
|
||
for i, line in enumerate(content): | ||
# Skip lines if needed | ||
if skip_next_lines > 0: | ||
skip_next_lines -= 1 | ||
continue | ||
|
||
# Skip duplicate imports | ||
if any(imp in line for imp in ["import jax", "from typing import", "from flax import linen"]): | ||
continue | ||
|
||
# Track when we're in __call__ method | ||
if "def __call__" in line: | ||
in_call_method = True | ||
# Fix the method signature | ||
fixed_content.append(" def __call__(") | ||
fixed_content.append(" self,") | ||
fixed_content.append(" inputs: Union[str, Dict[str, Any]],") | ||
fixed_content.append(" target_modality: str,") | ||
fixed_content.append(" context: Optional[Dict[str, Any]] = None,") | ||
fixed_content.append(" training: bool = False") | ||
fixed_content.append(" ) -> Tuple[jnp.ndarray, Dict[str, Any]]:\n") | ||
skip_next_lines = 9 # Skip the original malformed signature | ||
continue | ||
|
||
# Remove duplicate batch_size initialization | ||
if "batch_size = 1" in line and batch_size_initialized: | ||
continue | ||
|
||
if "batch_size = 1" in line and not batch_size_initialized: | ||
fixed_content.append(" batch_size = 1 # Initialize with default value\n") | ||
batch_size_initialized = True | ||
continue | ||
|
||
# Fix curr_batch_size assignments | ||
if "curr_batch_size" in line: | ||
# Remove extra spaces and fix indentation | ||
stripped = line.lstrip() | ||
if stripped.startswith("#"): | ||
continue | ||
spaces = " " if in_call_method else " " | ||
fixed_content.append(f"{spaces}{stripped}") | ||
continue | ||
|
||
# Fix duplicate _adjust_sequence_length calls | ||
if "_adjust_sequence_length" in line: | ||
if "embedded = self._adjust_sequence_length(" in line: | ||
fixed_content.append(" embedded = self._adjust_sequence_length(\n") | ||
fixed_content.append(" embedded,\n") | ||
fixed_content.append(" sequence_length\n") | ||
fixed_content.append(" )\n") | ||
skip_next_lines = 6 # Skip the duplicate call | ||
continue | ||
|
||
# Add the line if it's not being skipped | ||
if line.strip(): | ||
fixed_content.append(line) | ||
|
||
# Write the fixed content | ||
with open('src/models/text_to_anything.py', 'w') as f: | ||
f.writelines(fixed_content) | ||
|
||
if __name__ == "__main__": | ||
fix_text_to_anything() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters