Skip to content

Commit

Permalink
fix torch.Parameter vals if no pos
Browse files Browse the repository at this point in the history
  • Loading branch information
superDong1998 committed May 8, 2024
1 parent 7d1ada8 commit 88074bc
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
12 changes: 12 additions & 0 deletions frontend/guard_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1896,6 +1896,18 @@ def set_if_inplace_return() -> None:
# TODO: add map and set correct partial var
return
elif is_graph_func(func):
if func is operator.getitem:
obj_var = self.state.objects.get(args[0])
assert obj_var.extract_code_at_start[0]
obj_pos = obj_var.extract_code_at_start[0]
item_pos = StoreInIndex(obj_pos, id(obj_pos), args[1])
self.state.set_partial_var({
-1: [
PartialVar(node=None,
need_guard_check=False,
extract_code_at_start=[item_pos])
]
})
return
elif len(args) > 0 and isinstance(args[0], torch.nn.ModuleList):
return
Expand Down
5 changes: 4 additions & 1 deletion frontend/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import abstractmethod
from typing import Any, TYPE_CHECKING, Optional, Tuple, Iterable, Callable
from copy import copy

import torch
from frontend.utils import add_force_graph_break

from ..c_api import get_miss_locals
Expand Down Expand Up @@ -90,6 +90,9 @@ def make_output(self, name_in_graph_fn: str, store_pos: StorePos,
self.make_output_inner(name_in_graph_fn, store_pos, codegen,
in_return, idx)
for attr, var in self.modified_attrs.items():
if isinstance(var.obj, torch.nn.Parameter) and len(
var.extract_code_at_start) == 0:
continue
var.make_output(f'{name_in_graph_fn}_dot_{attr}',
StoreInAttr(store_pos, id(self.obj), attr),
codegen, False, id(getattr(self.obj, attr)))
Expand Down

0 comments on commit 88074bc

Please sign in to comment.