diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 9bd85463785ad..a88cf34302ce7 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -290,17 +290,8 @@ def __init__(self, pos, needed, provided): def _get_global_vars(func): - # Discussions: https://github.com/taichi-dev/taichi/issues/282 - global_vars = copy.copy(func.__globals__) - - freevar_names = func.__code__.co_freevars - closure = func.__closure__ - if closure: - freevar_values = list(map(lambda x: x.cell_contents, closure)) - for name, value in zip(freevar_names, freevar_values): - global_vars[name] = value - - return global_vars + closure_vars = inspect.getclosurevars(func) + return {**closure_vars.globals, **closure_vars.nonlocals} class Kernel: diff --git a/python/taichi/lang/stmt_builder.py b/python/taichi/lang/stmt_builder.py index 345d7dc3a978b..4708d788df4d4 100644 --- a/python/taichi/lang/stmt_builder.py +++ b/python/taichi/lang/stmt_builder.py @@ -595,11 +595,12 @@ def transform_as_kernel(): if impl.get_runtime().experimental_real_function: transform_as_kernel() else: - # Transform as func (all parameters passed by value) + # Transform as force-inlined func arg_decls = [] for i, arg in enumerate(args.args): - # Directly pass in template arguments, - # such as class instances ("self"), fields, SNodes, etc. + # Remove annotations because they are not used. + args.args[i].annotation = None + # Template arguments are passed by reference. if isinstance(ctx.func.argument_annotations[i], ti.template): ctx.create_variable(ctx.func.argument_names[i])