Skip to content

Commit

Permalink
[Lang] Support very basic python string.format() now (#2552)
Browse files Browse the repository at this point in the history
* [Lang] Support very basic python string.format() now

* fixed bug

* Auto Format

* fixed bug

* Auto Format

* Improved method

* Update python/taichi/lang/transformer.py

Co-authored-by: Taichi Gardener <[email protected]>
Co-authored-by: ljcc0930 <[email protected]>
  • Loading branch information
3 people authored Jul 20, 2021
1 parent e04ae46 commit 688f926
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
37 changes: 36 additions & 1 deletion python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,13 @@ def vars2entries(vars):
if hasattr(var, '__ti_repr__'):
res = var.__ti_repr__()
elif isinstance(var, (list, tuple)):
res = list_ti_repr(var)
res = var
# If the first element is '__ti_format__', this list is the result of ti_format.
if len(var) > 0 and isinstance(
var[0], str) and var[0] == '__ti_format__':
res = var[1:]
else:
res = list_ti_repr(var)
else:
yield var
continue
Expand Down Expand Up @@ -552,6 +558,35 @@ def fused_string(entries):
_ti_core.create_print(contentries)


@taichi_scope
def ti_format(*args):
content = args[0]
mixed = args[1:]
new_mixed = []
args = []
for x in mixed:
if isinstance(x, ti.Expr):
new_mixed.append('{}')
args.append(x)
else:
new_mixed.append(x)

try:
content = content.format(*new_mixed)
except ValueError:
print('Number formatting is not supported with Taichi fields')
exit(1)
res = content.split('{}')
assert len(res) == len(
args
) + 1, 'Number of args is different from number of positions provided in string'

for i in range(len(args)):
res.insert(i * 2 + 1, args[i])
res.insert(0, '__ti_format__')
return res


@taichi_scope
def ti_assert(cond, msg, extra_args):
# Mostly a wrapper to help us convert from Expr (defined in Python) to
Expand Down
7 changes: 3 additions & 4 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,6 @@ def external_func_call(func, args=[], outputs=[]):


def asm(source, inputs=[], outputs=[]):

_ti_core.insert_external_func_call(0, source, make_expr_group(inputs),
make_expr_group(outputs))

Expand Down Expand Up @@ -567,11 +566,11 @@ def rescale_index(a, b, I):
"""
assert isinstance(a, Expr) and a.is_global(), \
f"first arguement must be a field"
f"first arguement must be a field"
assert isinstance(b, Expr) and b.is_global(), \
f"second arguement must be a field"
f"second arguement must be a field"
assert isinstance(I, matrix.Matrix) and not I.is_global(), \
f"third arguement must be a grouped index"
f"third arguement must be a grouped index"
Ib = I.copy()
for n in range(min(I.n, min(len(a.shape), len(b.shape)))):
if a.shape[n] > b.shape[n]:
Expand Down
5 changes: 5 additions & 0 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,11 @@ def visit_Call(self, node):
if not ASTResolver.resolve_to(node.func, ti.static, globals()):
# Do not apply the generic visitor if the function called is ti.static
self.generic_visit(node)
if isinstance(node.func, ast.Attribute):
attr_name = node.func.attr
if attr_name == 'format':
node.args.insert(0, node.func.value)
node.func = self.parse_expr('ti.ti_format')
if isinstance(node.func, ast.Name):
func_name = node.func.id
if func_name == 'print':
Expand Down

0 comments on commit 688f926

Please sign in to comment.