From 4ca6ed0d20694e98aa02488675afb771f8bf47b4 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sun, 12 Mar 2023 22:06:32 -0400 Subject: [PATCH] [mypyc] Use table-driven helper for imports Change how imports (not from imports!) are processed so they can be table-driven and compact. Here's how it works: Import nodes are divided in groups (in the prebuild visitor). Each group consists of consecutive Import nodes: import mod <| group #1 import mod2 | def foo() -> None: import mod3 <- group #2 import mod4 <- group #3 Every time we encounter the first import of a group, build IR to call CPyImport_ImportMany() that will perform all of the group's imports in one go. Previously, each module would imported and placed in globals manually in IR, leading to some pretty verbose code. The other option to collect all imports and perform them all at once in the helper would remove even more ops, however, it's problematic for the same reasons from the previous commit (spoiler: it's not safe). Implementation notes: - I had to add support for loading the address of a static directly, so I shoehorned in LoadLiteral support for LoadAddress. - Unfortunately by replacing multiple nodes with a single function call at the IR level, the traceback line number is static. Even if an import several lines down a group fails, the line # of the first import in the group would be printed. To fix this, I had to make CPyImport_ImportMany() add the traceback entry itself on failure (instead of letting codegen handle it automatically). This is admittedly ugly. --- mypyc/codegen/emitfunc.py | 8 +- mypyc/ir/ops.py | 5 +- mypyc/ir/pprint.py | 5 + mypyc/irbuild/builder.py | 2 + mypyc/irbuild/ll_builder.py | 6 ++ mypyc/irbuild/prebuildvisitor.py | 25 ++++- mypyc/irbuild/statement.py | 88 +++++++++++----- mypyc/lib-rt/CPy.h | 2 + mypyc/lib-rt/misc_ops.c | 62 ++++++++++++ mypyc/primitives/misc_ops.py | 18 +++- mypyc/test-data/irbuild-basic.test | 155 ++++++++++++++++++++++------- mypyc/test-data/run-imports.test | 43 ++++++++ 12 files changed, 352 insertions(+), 67 deletions(-) diff --git a/mypyc/codegen/emitfunc.py b/mypyc/codegen/emitfunc.py index 3b544d5165dc7..c1ad0940d768e 100644 --- a/mypyc/codegen/emitfunc.py +++ b/mypyc/codegen/emitfunc.py @@ -727,7 +727,13 @@ def visit_get_element_ptr(self, op: GetElementPtr) -> None: def visit_load_address(self, op: LoadAddress) -> None: typ = op.type dest = self.reg(op) - src = self.reg(op.src) if isinstance(op.src, Register) else op.src + if isinstance(op.src, Register): + src = self.reg(op.src) + elif isinstance(op.src, LoadStatic): + prefix = self.PREFIX_MAP[op.src.namespace] + src = self.emitter.static_name(op.src.identifier, op.src.module_name, prefix) + else: + src = op.src self.emit_line(f"{dest} = ({typ._ctype})&{src};") def visit_keep_alive(self, op: KeepAlive) -> None: diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index adf24de235fff..412d36dd52d7b 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -1348,13 +1348,14 @@ class LoadAddress(RegisterOp): Attributes: type: Type of the loaded address(e.g. ptr/object_ptr) src: Source value (str for globals like 'PyList_Type', - Register for temporary values or locals) + Register for temporary values or locals, LoadStatic + for statics.) """ error_kind = ERR_NEVER is_borrowed = True - def __init__(self, type: RType, src: str | Register, line: int = -1) -> None: + def __init__(self, type: RType, src: str | Register | LoadStatic, line: int = -1) -> None: super().__init__(line) self.type = type self.src = src diff --git a/mypyc/ir/pprint.py b/mypyc/ir/pprint.py index 82e82913c9a67..4d10a91835cac 100644 --- a/mypyc/ir/pprint.py +++ b/mypyc/ir/pprint.py @@ -266,6 +266,11 @@ def visit_get_element_ptr(self, op: GetElementPtr) -> str: def visit_load_address(self, op: LoadAddress) -> str: if isinstance(op.src, Register): return self.format("%r = load_address %r", op, op.src) + elif isinstance(op.src, LoadStatic): + name = op.src.identifier + if op.src.module_name is not None: + name = f"{op.src.module_name}.{name}" + return self.format("%r = load_address %s :: %s", op, name, op.src.namespace) else: return self.format("%r = load_address %s", op, op.src) diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 136da8b9ff989..16bcfbcca4e5a 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -187,6 +187,8 @@ def __init__( self.encapsulating_funcs = pbv.encapsulating_funcs self.nested_fitems = pbv.nested_funcs.keys() self.fdefs_to_decorators = pbv.funcs_to_decorators + self.module_import_groups = pbv.module_import_groups + self.singledispatch_impls = singledispatch_impls self.visitor = visitor diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index d41b532f9228e..0b63f897a74ea 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -1693,6 +1693,12 @@ def new_list_op(self, values: list[Value], line: int) -> Value: def new_set_op(self, values: list[Value], line: int) -> Value: return self.call_c(new_set_op, values, line) + def setup_rarray(self, item_type: RType, values: Sequence[Value]) -> Value: + """Declare and initialize a new RArray, returning its address.""" + array = Register(RArray(item_type, len(values))) + self.add(AssignMulti(array, list(values))) + return self.add(LoadAddress(c_pointer_rprimitive, array)) + def shortcircuit_helper( self, op: str, diff --git a/mypyc/irbuild/prebuildvisitor.py b/mypyc/irbuild/prebuildvisitor.py index d994539550021..e33f517e380f9 100644 --- a/mypyc/irbuild/prebuildvisitor.py +++ b/mypyc/irbuild/prebuildvisitor.py @@ -5,18 +5,20 @@ Expression, FuncDef, FuncItem, + Import, LambdaExpr, MemberExpr, MypyFile, NameExpr, + Node, SymbolNode, Var, ) -from mypy.traverser import TraverserVisitor +from mypy.traverser import ExtendedTraverserVisitor from mypyc.errors import Errors -class PreBuildVisitor(TraverserVisitor): +class PreBuildVisitor(ExtendedTraverserVisitor): """Mypy file AST visitor run before building the IR. This collects various things, including: @@ -26,6 +28,7 @@ class PreBuildVisitor(TraverserVisitor): * Find non-local variables (free variables) * Find property setters * Find decorators of functions + * Find module import groups The main IR build pass uses this information. """ @@ -68,10 +71,28 @@ def __init__( # Map function to indices of decorators to remove self.decorators_to_remove: dict[FuncDef, list[int]] = decorators_to_remove + # Map starting module import to import groups. Each group is a + # series of imports with nothing between. + self.module_import_groups: dict[Import, list[Import]] = {} + self._current_import_group: Import | None = None + self.errors: Errors = errors self.current_file: MypyFile = current_file + def visit(self, o: Node) -> bool: + if isinstance(o, Import): + if self._current_import_group is not None: + self.module_import_groups[self._current_import_group].append(o) + else: + self.module_import_groups[o] = [o] + self._current_import_group = o + # Don't recurse into the import's assignments. + return False + + self._current_import_group = None + return True + def visit_decorator(self, dec: Decorator) -> None: if dec.decorators: # Only add the function being decorated if there exist diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 45105fc4d8fb3..9f6f6237bae09 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -53,6 +53,7 @@ LoadAddress, LoadErrorValue, LoadLiteral, + LoadStatic, MethodCall, RaiseStandardError, Register, @@ -63,6 +64,7 @@ ) from mypyc.ir.rtypes import ( RInstance, + c_pyssize_t_rprimitive, exc_rtuple, is_tagged, none_rprimitive, @@ -100,6 +102,7 @@ check_stop_op, coro_op, import_from_many_op, + import_many_op, send_op, type_op, yield_from_except_op, @@ -220,32 +223,69 @@ def transform_operator_assignment_stmt(builder: IRBuilder, stmt: OperatorAssignm def transform_import(builder: IRBuilder, node: Import) -> None: if node.is_mypy_only: return - globals = builder.load_globals_dict() - for node_id, as_name in node.ids: - builder.gen_import(node_id, node.line) - - # Update the globals dict with the appropriate module: - # * For 'import foo.bar as baz' we add 'foo.bar' with the name 'baz' - # * For 'import foo.bar' we add 'foo' with the name 'foo' - # Typically we then ignore these entries and access things directly - # via the module static, but we will use the globals version for modules - # that mypy couldn't find, since it doesn't analyze module references - # from those properly. - - # TODO: Don't add local imports to the global namespace - - # Miscompiling imports inside of functions, like below in import from. - if as_name: - name = as_name - base = node_id - else: - base = name = node_id.split(".")[0] - obj = builder.get_module(base, node.line) + # Imports (not from imports!) are processed in an odd way so they can be + # table-driven and compact. Here's how it works: + # + # Import nodes are divided in groups (in the prebuild visitor). Each group + # consists of consecutive Import nodes: + # + # import mod <| group #1 + # import mod2 | + # + # def foo() -> None: + # import mod3 <- group #2 + # + # import mod4 <- group #3 + # + # Every time we encounter the first import of a group, build IR to call a + # helper function that will perform all of the group's imports in one go. + if node not in builder.module_import_groups: + return - builder.gen_method_call( - globals, "__setitem__", [builder.load_str(name), obj], result_type=None, line=node.line - ) + modules = [] + statics = [] + # To show the right line number on failure, we have to add the traceback + # entry within the helper function (which is admittedly ugly). To drive + # this, we'll need the line number corresponding to each import. + import_lines = [] + for import_node in builder.module_import_groups[node]: + for mod_id, as_name in import_node.ids: + builder.imports[mod_id] = None + import_lines.append(Integer(import_node.line, c_pyssize_t_rprimitive)) + + module_static = LoadStatic(object_rprimitive, mod_id, namespace=NAMESPACE_MODULE) + static_ptr = builder.add(LoadAddress(object_pointer_rprimitive, module_static)) + statics.append(static_ptr) + # TODO: Don't add local imports to the global namespace + # Update the globals dict with the appropriate module: + # * For 'import foo.bar as baz' we add 'foo.bar' with the name 'baz' + # * For 'import foo.bar' we add 'foo' with the name 'foo' + # Typically we then ignore these entries and access things directly + # via the module static, but we will use the globals version for + # modules that mypy couldn't find, since it doesn't analyze module + # references from those properly. + if as_name or "." not in mod_id: + globals_base = None + else: + globals_base = mod_id.split(".")[0] + modules.append((mod_id, as_name, globals_base)) + + static_array_ptr = builder.builder.setup_rarray(object_pointer_rprimitive, statics) + import_line_ptr = builder.builder.setup_rarray(c_pyssize_t_rprimitive, import_lines) + function = "" if builder.fn_info.name == "" else builder.fn_info.name + builder.call_c( + import_many_op, + [ + builder.add(LoadLiteral(tuple(modules), object_rprimitive)), + static_array_ptr, + builder.load_globals_dict(), + builder.load_str(builder.module_path), + builder.load_str(function), + import_line_ptr, + ], + NO_TRACEBACK_LINE_NO, + ) def transform_import_from(builder: IRBuilder, node: ImportFrom) -> None: diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 4ceacef3b4a48..7a3e16fe9d658 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -622,6 +622,8 @@ PyObject *CPy_Super(PyObject *builtins, PyObject *self); PyObject *CPy_CallReverseOpMethod(PyObject *left, PyObject *right, const char *op, _Py_Identifier *method); +bool CPyImport_ImportMany(PyObject *modules, CPyModule **statics[], PyObject *globals, + PyObject *tb_path, PyObject *tb_function, Py_ssize_t *tb_lines); PyObject *CPyImport_ImportFromMany(PyObject *mod_id, PyObject *names, PyObject *as_names, PyObject *globals); diff --git a/mypyc/lib-rt/misc_ops.c b/mypyc/lib-rt/misc_ops.c index 5a64d4390cfc0..4d64270a62e90 100644 --- a/mypyc/lib-rt/misc_ops.c +++ b/mypyc/lib-rt/misc_ops.c @@ -669,6 +669,68 @@ CPy_Super(PyObject *builtins, PyObject *self) { return result; } +static bool import_single(PyObject *mod_id, + PyObject *as_name, + PyObject **mod_static, + PyObject *globals_base, + PyObject *globals) { + if (*mod_static == Py_None) { + CPyModule *mod = PyImport_Import(mod_id); + if (mod == NULL) { + return false; + } + *mod_static = mod; + } + + if (as_name == Py_None) { + as_name = mod_id; + } + PyObject *globals_id, *globals_name; + if (globals_base == Py_None) { + globals_id = mod_id; + globals_name = as_name; + } else { + globals_id = globals_name = globals_base; + } + PyObject *mod_dict = PyImport_GetModuleDict(); + CPyModule *globals_mod = CPyDict_GetItem(mod_dict, globals_id); + if (globals_mod == NULL) { + return false; + } + int ret = CPyDict_SetItem(globals, globals_name, globals_mod); + Py_DECREF(globals_mod); + if (ret < 0) { + return false; + } + + return true; +} + +// Table-driven import helper. See transform_import() in irbuild for the details. +bool CPyImport_ImportMany(PyObject *modules, CPyModule **statics[], PyObject *globals, + PyObject *tb_path, PyObject *tb_function, Py_ssize_t *tb_lines) { + for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(modules); i++) { + PyObject *module = PyTuple_GET_ITEM(modules, i); + PyObject *mod_id = PyTuple_GET_ITEM(module, 0); + PyObject *as_name = PyTuple_GET_ITEM(module, 1); + PyObject *globals_base = PyTuple_GET_ITEM(module, 2); + + if (!import_single(mod_id, as_name, statics[i], globals_base, globals)) { + const char *path = PyUnicode_AsUTF8(tb_path); + if (path == NULL) { + path = ""; + } + const char *function = PyUnicode_AsUTF8(tb_function); + if (function == NULL) { + function = ""; + } + CPy_AddTraceback(path, function, tb_lines[i], globals); + return false; + } + } + return true; +} + // This helper function is a simplification of cpython/ceval.c/import_from() static PyObject *CPyImport_ImportFrom(PyObject *module, PyObject *package_name, PyObject *import_name, PyObject *as_name) { diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index d4d9d96de182b..0e04c0471b8df 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -7,6 +7,7 @@ bit_rprimitive, bool_rprimitive, c_int_rprimitive, + c_pointer_rprimitive, c_pyssize_t_rprimitive, dict_rprimitive, int_rprimitive, @@ -111,7 +112,7 @@ is_borrowed=True, ) -# Import a module +# Import a module (plain) import_op = custom_op( arg_types=[str_rprimitive], return_type=object_rprimitive, @@ -119,6 +120,21 @@ error_kind=ERR_MAGIC, ) +# Import helper op (handles globals/statics & can import multiple modules) +import_many_op = custom_op( + arg_types=[ + object_rprimitive, + c_pointer_rprimitive, + object_rprimitive, + object_rprimitive, + object_rprimitive, + c_pointer_rprimitive, + ], + return_type=bit_rprimitive, + c_function_name="CPyImport_ImportMany", + error_kind=ERR_FALSE, +) + # From import helper op import_from_many_op = custom_op( arg_types=[object_rprimitive, object_rprimitive, object_rprimitive, object_rprimitive], diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 8573b05c0591d..46002ac1f1ce7 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -682,6 +682,95 @@ L0: r5 = unbox(int, r4) return r5 +[case testImport_toplevel] +import sys +import enum as enum2 +import collections.abc +import collections.abc as abc2 +_ = "filler" +import single +single.hello() + +[file single.py] +def hello() -> None: + print("hello, world") + +[out] +def __top_level__(): + r0, r1 :: object + r2 :: bit + r3 :: str + r4 :: object + r5, r6, r7, r8 :: object_ptr + r9 :: object_ptr[4] + r10 :: c_ptr + r11 :: native_int[4] + r12 :: c_ptr + r13 :: object + r14 :: dict + r15, r16 :: str + r17 :: bit + r18 :: str + r19 :: dict + r20 :: str + r21 :: int32 + r22 :: bit + r23 :: object_ptr + r24 :: object_ptr[1] + r25 :: c_ptr + r26 :: native_int[1] + r27 :: c_ptr + r28 :: object + r29 :: dict + r30, r31 :: str + r32 :: bit + r33 :: object + r34 :: str + r35, r36 :: object +L0: + r0 = builtins :: module + r1 = load_address _Py_NoneStruct + r2 = r0 != r1 + if r2 goto L2 else goto L1 :: bool +L1: + r3 = 'builtins' + r4 = PyImport_Import(r3) + builtins = r4 :: module +L2: + r5 = load_address sys :: module + r6 = load_address enum :: module + r7 = load_address collections.abc :: module + r8 = load_address collections.abc :: module + r9 = [r5, r6, r7, r8] + r10 = load_address r9 + r11 = [1, 2, 3, 4] + r12 = load_address r11 + r13 = (('sys', None, None), ('enum', 'enum2', None), ('collections.abc', None, 'collections'), ('collections.abc', 'abc2', None)) + r14 = __main__.globals :: static + r15 = 'main' + r16 = '' + r17 = CPyImport_ImportMany(r13, r10, r14, r15, r16, r12) + r18 = 'filler' + r19 = __main__.globals :: static + r20 = '_' + r21 = CPyDict_SetItem(r19, r20, r18) + r22 = r21 >= 0 :: signed + r23 = load_address single :: module + r24 = [r23] + r25 = load_address r24 + r26 = [6] + r27 = load_address r26 + r28 = (('single', None, None),) + r29 = __main__.globals :: static + r30 = 'main' + r31 = '' + r32 = CPyImport_ImportMany(r28, r25, r29, r30, r31, r27) + r33 = single :: module + r34 = 'hello' + r35 = CPyObject_GetAttr(r33, r34) + r36 = PyObject_CallFunctionObjArgs(r35, 0) + return 1 + [case testFromImport_toplevel] from testmodule import g, h from testmodule import h as two @@ -3388,47 +3477,39 @@ x = 1 [file p/m.py] [out] def f(): - r0 :: dict - r1, r2 :: object - r3 :: bit - r4 :: str + r0 :: object_ptr + r1 :: object_ptr[1] + r2 :: c_ptr + r3 :: native_int[1] + r4 :: c_ptr r5 :: object r6 :: dict - r7 :: str - r8 :: object - r9 :: str - r10 :: int32 - r11 :: bit - r12 :: dict + r7, r8 :: str + r9 :: bit + r10 :: dict + r11 :: str + r12 :: object r13 :: str r14 :: object - r15 :: str - r16 :: object - r17 :: int -L0: - r0 = __main__.globals :: static - r1 = p.m :: module - r2 = load_address _Py_NoneStruct - r3 = r1 != r2 - if r3 goto L2 else goto L1 :: bool -L1: - r4 = 'p.m' - r5 = PyImport_Import(r4) - p.m = r5 :: module -L2: - r6 = PyImport_GetModuleDict() - r7 = 'p' - r8 = CPyDict_GetItem(r6, r7) - r9 = 'p' - r10 = CPyDict_SetItem(r0, r9, r8) - r11 = r10 >= 0 :: signed - r12 = PyImport_GetModuleDict() - r13 = 'p' - r14 = CPyDict_GetItem(r12, r13) - r15 = 'x' - r16 = CPyObject_GetAttr(r14, r15) - r17 = unbox(int, r16) - return r17 + r15 :: int +L0: + r0 = load_address p.m :: module + r1 = [r0] + r2 = load_address r1 + r3 = [2] + r4 = load_address r3 + r5 = (('p.m', None, 'p'),) + r6 = __main__.globals :: static + r7 = 'main' + r8 = 'f' + r9 = CPyImport_ImportMany(r5, r2, r6, r7, r8, r4) + r10 = PyImport_GetModuleDict() + r11 = 'p' + r12 = CPyDict_GetItem(r10, r11) + r13 = 'x' + r14 = CPyObject_GetAttr(r12, r13) + r15 = unbox(int, r14) + return r15 [case testIsinstanceBool] def f(x: object) -> bool: diff --git a/mypyc/test-data/run-imports.test b/mypyc/test-data/run-imports.test index c6d5bdb3d8649..fd289458e786a 100644 --- a/mypyc/test-data/run-imports.test +++ b/mypyc/test-data/run-imports.test @@ -18,10 +18,12 @@ def test_import_submodule_within_function() -> None: import pkg.mod assert pkg.x == 1 assert pkg.mod.y == 2 + assert "pkg.mod" not in globals(), "the root module should be in globals!" def test_import_as_submodule_within_function() -> None: import pkg.mod as mm assert mm.y == 2 + assert "pkg.mod" not in globals(), "the root module should be in globals!" # TODO: Don't add local imports to globals() # @@ -192,3 +194,44 @@ a.x = 10 x = 20 [file driver.py] import native + +[case testLazyImport] +import shared + +def do_import() -> None: + import a + +assert shared.counter == 0 +do_import() +assert shared.counter == 1 + +[file a.py] +import shared +shared.counter += 1 + +[file shared.py] +counter = 0 + +[case testDelayedImport] +import a +print("inbetween") +import b + +[file a.py] +print("first") + +[file b.py] +print("last") + +[out] +first +inbetween +last + +[case testImportErrorLineNumber] +try: + import enum + import dataclasses, missing # type: ignore[import] +except ImportError as e: + line = e.__traceback__.tb_lineno # type: ignore[attr-defined] + assert line == 3, f"traceback's line number is {line}, expected 3"