Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Compat][3.11] support POP_JUMP for is None and is not None #377

Merged
merged 1 commit into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sot/opcode_translator/executor/dispatch_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,13 @@ def operator_BAD(left, right):
pass


def operator_is_none(val):
pass


def operator_is_not_none(val):
pass


def tensor_numel(x):
pass
34 changes: 22 additions & 12 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import types
from dataclasses import dataclass
from itertools import chain
from typing import Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple

import opcode

Expand Down Expand Up @@ -42,6 +42,8 @@
operator_BAD,
operator_exception_match,
operator_in,
operator_is_none,
operator_is_not_none,
operator_not_in,
)
from .dispatcher import Dispatcher
Expand Down Expand Up @@ -382,7 +384,7 @@ def inner(self: OpcodeExecutorBase, instr: Instruction):
return inner


def pop_jump_if_op_wrapper(fn: Callable[[VariableBase], bool]):
def pop_jump_if_op_wrapper(fns: list[Callable[[Any], Any]]):
"""
A decorator function that wraps a POP_JUMP_*_IF_* opcode operation and applies certain functionality to it.

Expand All @@ -406,16 +408,24 @@ def inner(self: OpcodeExecutorBase, instr: Instruction):
"""
pred_obj = self.stack.pop()

if isinstance(pred_obj, (ConstantVariable, ContainerVariable)):
try:
self._graph.add_global_guarded_variable(pred_obj)
is_jump = fn(pred_obj)
res = pred_obj
for fn in fns:
res = BuiltinVariable(
fn, graph=self._graph, tracker=DanglingTracker()
)(res)

assert isinstance(res, ConstantVariable)
is_jump = res.get_py_value()
assert isinstance(is_jump, bool)
if is_jump:
assert instr.jump_to is not None
self.jump_to(instr.jump_to)
return
raise NotImplementException(
f"Currently don't support predicate a non-const / non-tensor obj, but got {pred_obj}"
)
except BreakGraphError:
raise NotImplementException(
f"Currently don't support predicate {pred_obj.__class__.__name__}"
)

return inner

Expand Down Expand Up @@ -1460,19 +1470,19 @@ def JUMP_IF_TRUE_OR_POP(self, instr: Instruction):
"Currently don't support predicate a non-const / non-tensor obj."
)

POP_JUMP_IF_FALSE = pop_jump_if_op_wrapper(lambda x: not bool(x))
POP_JUMP_IF_FALSE = pop_jump_if_op_wrapper([bool, operator.not_])
POP_JUMP_FORWARD_IF_FALSE = POP_JUMP_IF_FALSE
POP_JUMP_BACKWARD_IF_FALSE = POP_JUMP_IF_FALSE

POP_JUMP_IF_TRUE = pop_jump_if_op_wrapper(bool)
POP_JUMP_IF_TRUE = pop_jump_if_op_wrapper([bool])
POP_JUMP_FORWARD_IF_TRUE = POP_JUMP_IF_TRUE
POP_JUMP_BACKWARD_IF_TRUE = POP_JUMP_IF_TRUE

POP_JUMP_FORWARD_IF_NONE = pop_jump_if_op_wrapper(lambda x: x.is_none())
POP_JUMP_FORWARD_IF_NONE = pop_jump_if_op_wrapper([operator_is_none])
POP_JUMP_BACKWARD_IF_NONE = POP_JUMP_FORWARD_IF_NONE

POP_JUMP_FORWARD_IF_NOT_NONE = pop_jump_if_op_wrapper(
lambda x: not x.is_none()
[operator_is_not_none]
)
POP_JUMP_BACKWARD_IF_NOT_NONE = POP_JUMP_FORWARD_IF_NOT_NONE

Expand Down
21 changes: 21 additions & 0 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
)
from .dispatch_functions import (
operator_in,
operator_is_none,
operator_is_not_none,
operator_not_in,
raise_break_graph_fn,
tensor_numel,
Expand Down Expand Up @@ -691,6 +693,25 @@ def is_not_func(var: VariableBase, other: VariableBase):
return handler(var, other).bool_not()


# is None
Dispatcher.register(
operator_is_none,
("VariableBase",),
lambda var: BuiltinVariable(operator.is_, var.graph, DanglingTracker())(
var, ConstantVariable.wrap_literal(None, var.graph)
),
)

# is not None
Dispatcher.register(
operator_is_not_none,
("VariableBase",),
lambda var: BuiltinVariable(operator.is_not, var.graph, DanglingTracker())(
var, ConstantVariable.wrap_literal(None, var.graph)
),
)


# NOTE(SigureMo): Don't directly capture free var inside for-loop, use partial instead.
# ```python
# lambdas = []
Expand Down
2 changes: 0 additions & 2 deletions tests/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ failed_tests=()

py311_skiped_tests=(
./test_19_closure.py
./test_guard_user_defined_fn.py
./test_resnet.py
./test_tensor_dtype_in_guard.py
)

Expand Down