Skip to content

Commit

Permalink
Correctly access partitioned tensors with index math
Browse files Browse the repository at this point in the history
  • Loading branch information
nandeeka committed Aug 23, 2024
1 parent 856490e commit 5e4f221
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 31 deletions.
28 changes: 22 additions & 6 deletions teaal/trans/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"""
from copy import deepcopy

from sympy import Symbol # type: ignore
from sympy import Add, Symbol # type: ignore
from typing import List, Optional

from teaal.hifiber import *
Expand Down Expand Up @@ -148,19 +148,35 @@ def __build_access(self, rank: str) -> Expression:
part_ir = self.program.get_partitioning()
root, suffix = part_ir.split_rank_name(rank.upper())

# This is not the innermost rank
if len(suffix) > 0 and suffix != "0" and suffix[-1] != "I":
return EVar(rank)

# If this rank is the result of flattening, then build the access
# as a tuple of the constituent ranks
if part_ir.is_flattened(rank.upper()):
flat_ranks = self.program.get_loop_order().get_iter_ranks(rank.upper())
return ETuple([EVar(frank.lower()) for frank in flat_ranks])

# Otherwise, this is the innermost rank; so translate
sexpr = self.program.get_coord_math().get_trans(root.lower())

# This is not the innermost rank, so we only care about the term with
# the root
if len(suffix) > 0 and suffix != "0" and suffix[-1] != "I":
# Extract the relevant term
opt_loop_rank = part_ir.partition_rank((rank.upper(),))
assert opt_loop_rank is not None
loop_rank = opt_loop_rank[0]

loop_root, loop_suffix = part_ir.split_rank_name(loop_rank)

terms = [
t for t in Add.make_args(sexpr) if Symbol(
loop_root.lower()) in t.free_symbols]
assert len(terms) == 1

term = terms[0].subs(
Symbol(
loop_root.lower()), Symbol(
loop_rank.lower()))
return CoordAccess.build_expr(term)

# Now, we need to replace the roots with their dynamic names
for symbol in sexpr.atoms(Symbol):
# Fix dynamic partitioning variable name
Expand Down
46 changes: 21 additions & 25 deletions tests/integration/demo.yaml
Original file line number Diff line number Diff line change
@@ -1,30 +1,26 @@
einsum:
declaration: # Ranks are listed alphabetically in this section
TS: [T, P1, P0, E]
ND: [P1, P0]
NTS: [T, P1, P0, E]
PF: [P1, P0, E, C]
OF: [P1, E]
MS: [P1, P0, E]
PS: [P1, P0, E, C]
TOP: [P1, E, C, P0]
declaration:
I: [B, C, H, W]
F: [C, M, R, S]
O: [B, M, P, Q]
expressions:
- ND[p1, p0] = TS[t, p1, p0, e]
- NTS[t, p1, p0, e] = TS[t, p1, p0, e] * ND[p1, p0]
- PF[p1, p0, e, c] = NTS[t, p1, p0, e]
- OF[p1, e] = PF[p1, p0, e, c]
- MS[p1, p0, e] = NTS[t, p1, p0, e]
- PS[p1, p0, e, c] = MS[p1, p0, e] * OF[p1, e]
- TOP[p1, e, c, p0] = PF[p1, p0, e, c] + PS[p1, p0, e, c]
- O[b, m, p, q] = I[b, c, 4*p+r, 4*q+s]*F[c, m, r, s]
mapping:
rank-order:
PF: [P1, E, P0, C]
PS: [P1, E, P0, C]
TOP: [P1, E, C, P0]
I: [B, C, H, W]
F: [M, C, R, S]
O: [B, M, P, Q]
partitioning:
O:
M:
- uniform_shape(32)
- uniform_shape(16)
P:
- uniform_shape(P0)
H: [follow(P)]
loop-order:
ND: [T, P1, P0, E]
PF: [T, P1, E, P0, C]
MS: [T, P1, P0, E]
OF: [P1, E, P0, C]
PS: [P1, E, P0, C]
TOP: [P1, E, C, P0]
O: [C, M2, B, M1, P1, P0, R, Q, S, M0]
spacetime:
O:
space: [P0]
time: [C, M2, B, M1, P1, R, Q, S, M0]
39 changes: 39 additions & 0 deletions tests/trans/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,30 @@ def create_conv():
return Program(Einsum.from_str(yaml), Mapping.from_str(yaml))


def create_conv_step():
yaml = """
einsum:
declaration:
F: [S]
I: [W]
O: [Q]
expressions:
- O[q] = I[2 * q + s] * F[s]
mapping:
partitioning:
O:
Q: [uniform_shape(10)]
W: [follow(Q)]
loop-order:
O: [Q1, Q0, S]
spacetime:
O:
space: []
time: [Q1, Q0, S]
"""
return Program(Einsum.from_str(yaml), Mapping.from_str(yaml))


def test_create_canvas():
program = create_spacetime()
program.add_einsum(0)
Expand Down Expand Up @@ -282,6 +306,21 @@ def test_add_activity_conv():
assert canvas.add_activity().gen(0) == hifiber


def test_add_activity_conv_step():
program = create_conv_step()
program.add_einsum(0)
part_ir = program.get_partitioning()

for tensor in program.get_equation().get_tensors():
program.apply_all_partitioning(tensor)

canvas = Canvas(program)
canvas.create_canvas()

hifiber = "canvas.addActivity((2 * q1, s + 2 * q0), (s,), (q1, q0), spacetime=((), (q1_pos, q0_pos, s_pos)))"
assert canvas.add_activity().gen(0) == hifiber


def test_display_canvas_no_canvas():
program = create_default()
program.add_einsum(0)
Expand Down

0 comments on commit 5e4f221

Please sign in to comment.