-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: resolve indentation and syntax issues in text_to_anything.py
- Loading branch information
1 parent
dc9fcd5
commit 63a6069
Showing
2 changed files
with
81 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
def fix_text_to_anything(): | ||
with open('src/models/text_to_anything.py', 'r') as f: | ||
content = f.readlines() | ||
|
||
# Add missing imports at the top | ||
imports = [ | ||
"import jax.numpy as jnp\n", | ||
"from typing import Dict, List, Optional, Tuple, Union\n", | ||
"from flax import linen as nn\n" | ||
] | ||
|
||
# Find where to insert imports | ||
for i, line in enumerate(content): | ||
if line.startswith("from flax import struct"): | ||
content = content[:i] + imports + content[i:] | ||
break | ||
|
||
# Initialize variables properly | ||
fixed_content = [] | ||
in_call_method = False | ||
batch_size_initialized = False | ||
|
||
for i, line in enumerate(content): | ||
# Skip the original imports we're replacing | ||
if any(imp in line for imp in ["import jax", "from typing import", "from flax import linen"]): | ||
continue | ||
|
||
# Track when we're in the __call__ method | ||
if "def __call__" in line: | ||
in_call_method = True | ||
|
||
if in_call_method and "encodings = {}" in line: | ||
fixed_content.append(line) | ||
# Add batch size initialization with proper indentation | ||
fixed_content.append(" batch_size = 1 # Initialize with default value\n") | ||
batch_size_initialized = True | ||
continue | ||
|
||
# Fix the commented out batch_size assignments | ||
if line.strip().startswith("#") and "curr_batch_size" in line: | ||
# Remove comment and TODO, maintain indentation | ||
spaces = len(line) - len(line.lstrip()) | ||
clean_line = line[line.index("curr_batch_size"):].strip() | ||
clean_line = clean_line.replace("# TODO: Remove or use this variable", "") | ||
fixed_content.append(" " * spaces + clean_line + "\n") | ||
continue | ||
|
||
# Fix indentation after if batch_size is None | ||
if "if batch_size is None:" in line: | ||
fixed_content.append(line) | ||
next_line = content[i + 1] | ||
if "#" in next_line and "batch_size = curr_batch_size" in next_line: | ||
spaces = len(line) - len(line.lstrip()) + 4 # Add 4 spaces for indentation | ||
fixed_content.append(" " * spaces + "batch_size = curr_batch_size\n") | ||
continue | ||
|
||
# Fix the sequence length adjustment indentation | ||
if "_adjust_sequence_length" in line and "embedded" in line: | ||
spaces = len(line) - len(line.lstrip()) | ||
fixed_content.append(" " * spaces + "embedded = self._adjust_sequence_length(\n") | ||
fixed_content.append(" " * (spaces + 4) + "embedded,\n") | ||
fixed_content.append(" " * (spaces + 4) + "sequence_length\n") | ||
fixed_content.append(" " * spaces + ")\n") | ||
continue | ||
|
||
if not batch_size_initialized or line.strip() != "": | ||
fixed_content.append(line) | ||
|
||
# Write the fixed content | ||
with open('src/models/text_to_anything.py', 'w') as f: | ||
f.writelines(fixed_content) | ||
|
||
if __name__ == "__main__": | ||
fix_text_to_anything() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters