From 688f9267b7a5dfb7b83786e79c90995cb2a773c3 Mon Sep 17 00:00:00 2001 From: Jiasheng Zhang Date: Tue, 20 Jul 2021 17:32:32 +0800 Subject: [PATCH] [Lang] Support very basic python string.format() now (#2552) * [Lang] Support very basic python string.format() now * fixed bug * Auto Format * fixed bug * Auto Format * Improved method * Update python/taichi/lang/transformer.py Co-authored-by: Taichi Gardener Co-authored-by: ljcc0930 --- python/taichi/lang/impl.py | 37 ++++++++++++++++++++++++++++++- python/taichi/lang/ops.py | 7 +++--- python/taichi/lang/transformer.py | 5 +++++ 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index e2173944feace..819b0e5aab532 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -517,7 +517,13 @@ def vars2entries(vars): if hasattr(var, '__ti_repr__'): res = var.__ti_repr__() elif isinstance(var, (list, tuple)): - res = list_ti_repr(var) + res = var + # If the first element is '__ti_format__', this list is the result of ti_format. + if len(var) > 0 and isinstance( + var[0], str) and var[0] == '__ti_format__': + res = var[1:] + else: + res = list_ti_repr(var) else: yield var continue @@ -552,6 +558,35 @@ def fused_string(entries): _ti_core.create_print(contentries) +@taichi_scope +def ti_format(*args): + content = args[0] + mixed = args[1:] + new_mixed = [] + args = [] + for x in mixed: + if isinstance(x, ti.Expr): + new_mixed.append('{}') + args.append(x) + else: + new_mixed.append(x) + + try: + content = content.format(*new_mixed) + except ValueError: + print('Number formatting is not supported with Taichi fields') + exit(1) + res = content.split('{}') + assert len(res) == len( + args + ) + 1, 'Number of args is different from number of positions provided in string' + + for i in range(len(args)): + res.insert(i * 2 + 1, args[i]) + res.insert(0, '__ti_format__') + return res + + @taichi_scope def ti_assert(cond, msg, extra_args): # Mostly a wrapper to help us convert from Expr (defined in Python) to diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index 6ea9fa84ac552..4ceac55792eb0 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -526,7 +526,6 @@ def external_func_call(func, args=[], outputs=[]): def asm(source, inputs=[], outputs=[]): - _ti_core.insert_external_func_call(0, source, make_expr_group(inputs), make_expr_group(outputs)) @@ -567,11 +566,11 @@ def rescale_index(a, b, I): """ assert isinstance(a, Expr) and a.is_global(), \ - f"first arguement must be a field" + f"first arguement must be a field" assert isinstance(b, Expr) and b.is_global(), \ - f"second arguement must be a field" + f"second arguement must be a field" assert isinstance(I, matrix.Matrix) and not I.is_global(), \ - f"third arguement must be a grouped index" + f"third arguement must be a grouped index" Ib = I.copy() for n in range(min(I.n, min(len(a.shape), len(b.shape)))): if a.shape[n] > b.shape[n]: diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index aebe0fbe418de..b256b8e926c66 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -626,6 +626,11 @@ def visit_Call(self, node): if not ASTResolver.resolve_to(node.func, ti.static, globals()): # Do not apply the generic visitor if the function called is ti.static self.generic_visit(node) + if isinstance(node.func, ast.Attribute): + attr_name = node.func.attr + if attr_name == 'format': + node.args.insert(0, node.func.value) + node.func = self.parse_expr('ti.ti_format') if isinstance(node.func, ast.Name): func_name = node.func.id if func_name == 'print':