diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index b11856727d54..64e4d4ea5ec6 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -1896,6 +1896,18 @@ def set_if_inplace_return() -> None: # TODO: add map and set correct partial var return elif is_graph_func(func): + if func is operator.getitem: + obj_var = self.state.objects.get(args[0]) + assert obj_var.extract_code_at_start[0] + obj_pos = obj_var.extract_code_at_start[0] + item_pos = StoreInIndex(obj_pos, id(obj_pos), args[1]) + self.state.set_partial_var({ + -1: [ + PartialVar(node=None, + need_guard_check=False, + extract_code_at_start=[item_pos]) + ] + }) return elif len(args) > 0 and isinstance(args[0], torch.nn.ModuleList): return diff --git a/frontend/variables/base.py b/frontend/variables/base.py index 2278e0d3cc86..e4e5ba1b6beb 100644 --- a/frontend/variables/base.py +++ b/frontend/variables/base.py @@ -2,7 +2,7 @@ from abc import abstractmethod from typing import Any, TYPE_CHECKING, Optional, Tuple, Iterable, Callable from copy import copy - +import torch from frontend.utils import add_force_graph_break from ..c_api import get_miss_locals @@ -90,6 +90,9 @@ def make_output(self, name_in_graph_fn: str, store_pos: StorePos, self.make_output_inner(name_in_graph_fn, store_pos, codegen, in_return, idx) for attr, var in self.modified_attrs.items(): + if isinstance(var.obj, torch.nn.Parameter) and len( + var.extract_code_at_start) == 0: + continue var.make_output(f'{name_in_graph_fn}_dot_{attr}', StoreInAttr(store_pos, id(self.obj), attr), codegen, False, id(getattr(self.obj, attr)))