Skip to content

Commit

Permalink
[SOT][3.13] Support closure (#69753)
Browse files Browse the repository at this point in the history
  • Loading branch information
gouzil authored Nov 28, 2024
1 parent 6c1c486 commit 9d49a11
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
SymbolicVariable,
TensorVariable,
TupleVariable,
UserCodeVariable,
UserDefinedFunctionVariable,
UserDefinedGeneratorFunctionVariable,
VariableBase,
Expand Down Expand Up @@ -1384,14 +1385,28 @@ def MAKE_FUNCTION(self, instr: Instruction):

# the function has no default values in 3.13
if sys.version_info >= (3, 13):
flag = 0
else:
flag = instr.arg
if len(codeobj.get_py_value().co_freevars) > 0:
self.stack.push(
UserCodeVariable(
codeobj, self._graph, DummyTracker([codeobj])
)
)
else:
self.push_new_fn_on_stack(
codeobj.get_py_value(),
global_dict,
fn_name.get_py_value(),
(),
(),
related_list,
(),
)
return

flag = instr.arg
closure, related_list, kw_defaults, default_args = (
self.attach_new_attribute(flag, related_list)
)

self.push_new_fn_on_stack(
codeobj.get_py_value(),
global_dict,
Expand All @@ -1404,13 +1419,36 @@ def MAKE_FUNCTION(self, instr: Instruction):

def SET_FUNCTION_ATTRIBUTE(self, instr: Instruction):
origin_func = self.stack.pop()
flag = instr.arg

if isinstance(origin_func, UserCodeVariable):
origin_codeobj = origin_func.codeobj
fn_name = ConstantVariable(
origin_codeobj.value.co_qualname,
self._graph,
DummyTracker([origin_codeobj]),
)
related_list = [fn_name, origin_codeobj]
closure, related_list, kw_defaults, default_args = (
self.attach_new_attribute(flag, related_list)
)
self.push_new_fn_on_stack(
origin_codeobj.get_py_value(),
self._globals.get_value(),
fn_name.get_py_value(),
default_args,
closure,
related_list,
kw_defaults,
)
return

# The object we manipulate must be a functionVariable
assert isinstance(
origin_func,
(UserDefinedGeneratorFunctionVariable, UserDefinedFunctionVariable),
), f"The object we manipulate must be a function object. But now got {type(origin_func)}"
origin_func_val = origin_func.get_py_value()
flag = instr.arg
related_list = [origin_func]
closure, related_list, kw_defaults, default_args = (
self.attach_new_attribute(flag, related_list)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
MethodVariable,
PaddleApiVariable,
PaddleLayerVariable,
UserCodeVariable,
UserDefinedFunctionVariable,
UserDefinedGeneratorFunctionVariable,
UserDefinedLayerVariable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from ....utils.exceptions import (
BreakGraphError,
FallbackError,
InnerError,
SotErrorBase,
)
from ..dispatcher import Dispatcher
Expand All @@ -65,7 +66,12 @@
Tracker,
)
from .base import VariableBase, VariableFactory
from .basic import ConstantVariable, PrintStmtVariable, SliceVariable
from .basic import (
ConstantVariable,
ObjectVariable,
PrintStmtVariable,
SliceVariable,
)

if TYPE_CHECKING:
from ..function_graph import FunctionGraph
Expand Down Expand Up @@ -230,6 +236,22 @@ def main_info(self) -> dict[str, Any]:
}


class UserCodeVariable(FunctionVariable):
"""
UserCodeVariable is a subclass of Function
Variable used to wrap a make function variable.
"""

def __init__(
self, codeobj: ObjectVariable, graph: FunctionGraph, tracker: Tracker
):
super().__init__(codeobj, graph, tracker)
self.codeobj = codeobj

def call_function(self, /, *args, **kwargs):
raise InnerError("UserCodeVariable call_function is not implemented.")


class PaddleApiVariable(FunctionVariable):
"""
PaddleApiVariable is a subclass of FunctionVariable used to wrap a paddlepaddle API function.
Expand Down
1 change: 0 additions & 1 deletion test/sot/skip_files_py313

This file was deleted.

3 changes: 1 addition & 2 deletions test/sot/test_13_make_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ def test_simple(self):
self.assert_results(make_fn_default, paddle.to_tensor(1))
self.assert_results(make_fn_annotation, paddle.to_tensor(1))
self.assert_results(make_fn_kwdefault, paddle.to_tensor(1))
# self.assert_results(make_fn_closure, paddle.to_tensor(1))
# we haven't pass this test yet
self.assert_results(make_fn_closure, paddle.to_tensor(1))
self.assert_results(make_fn_mix, paddle.to_tensor(1))


Expand Down

0 comments on commit 9d49a11

Please sign in to comment.