Skip to content

Commit

Permalink
Use <<= if there is no reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
nandeeka committed Oct 30, 2024
1 parent ea99815 commit 68c53f1
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 19 deletions.
11 changes: 10 additions & 1 deletion teaal/trans/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,11 @@ def make_update(self) -> Statement:

# Create the final statement
out_name = self.program.get_equation().get_output().root_name().lower() + "_ref"
return SIAssign(AVar(out_name), OAdd(), sum_)

if self.__no_reduction():
return SIAssign(AVar(out_name), OLtLt(), sum_)
else:
return SIAssign(AVar(out_name), OAdd(), sum_)

@staticmethod
def __add_operator(
Expand All @@ -288,6 +292,11 @@ def __add_enumerate(self, rank: str, expr: Expression) -> Expression:
expr = EFunc("enumerate", [AJust(expr)])
return expr

def __no_reduction(self) -> bool:
out_ranks = self.program.get_equation().get_output().get_ranks()
all_ranks = self.program.get_loop_order().get_ranks()
return out_ranks == all_ranks

def __need_enumerate(self, rank: str) -> bool:
"""
Returns True if new need to enumerate over this rank
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/example7.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
b_m = B_MN.getRoot()
for m, (z_n, (_, a_n, b_n)) in z_m << (a_m | b_m):
for n, (z_ref, (_, a_val, b_val)) in z_n << (a_n | b_n):
z_ref += a_val + b_val
z_ref <<= a_val + b_val
Z_MN = Tensor(rank_ids=["M", "N"], name="Z")
z_m = Z_MN.getRoot()
a_m = A_MN.getRoot()
b_m = B_MN.getRoot()
for m, (z_n, (a_n, b_n)) in z_m << (a_m & b_m):
for n, (z_ref, (a_val, b_val)) in z_n << (a_n & b_n):
z_ref += a_val * b_val
z_ref <<= a_val * b_val
2 changes: 1 addition & 1 deletion tests/integration/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
c_m = C_MN.getRoot()
for m, (z_n, (_, t1_n, c_n)) in z_m << (t1_m | c_m):
for n, (z_ref, (_, t1_val, c_val)) in z_n << (t1_n | c_n):
z_ref += a * t1_val + b * c_val
z_ref <<= a * t1_val + b * c_val
2 changes: 1 addition & 1 deletion tests/integration/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
t1_m = T1_M.getRoot()
c_m = C_M.getRoot()
for m, (z_ref, (_, t1_val, c_val)) in z_m << (t1_m | c_m):
z_ref += a * t1_val + b * c_val
z_ref <<= a * t1_val + b * c_val
2 changes: 1 addition & 1 deletion tests/integration/nrm_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
for b, (t_i, v_i) in t_b << v_b:
for i, (t_j, v_j) in t_i << v_i:
for j, (t_ref, v_val) in t_j << v_j:
t_ref += v_val
t_ref <<= v_val
Q_ = Tensor(rank_ids=[], name="Q")
q_ref = Q_.getRoot()
v_a = V_ABIJ.getRoot()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/outerprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
b_n = B_N.getRoot()
for m, (z_n, a_val) in z_m << a_m:
for n, (z_ref, b_val) in z_n << b_n:
z_ref += a_val * b_val
z_ref <<= a_val * b_val
2 changes: 1 addition & 1 deletion tests/integration/sddmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
t1_m = T1_MN.getRoot()
for m, (z_n, (c_n, t1_n)) in z_m << (c_m & t1_m):
for n, (z_ref, (c_val, t1_val)) in z_n << (c_n & t1_n):
z_ref += c_val * t1_val
z_ref <<= c_val * t1_val
5 changes: 4 additions & 1 deletion tests/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def read_hifiber(filename):


def test_integration():
errors = []
for test_name in test_names:
filename = 'tests/integration/' + test_name

Expand All @@ -43,4 +44,6 @@ def test_integration():
hifiber = read_hifiber(filename + ".py")
if output != hifiber:
print(output)
assert output == hifiber, test_name + " integration test failed!"
errors.append(test_name)

assert not errors, "Integration tests " + str(errors) + " failed!"
6 changes: 3 additions & 3 deletions tests/trans/test_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,18 +635,18 @@ def test_make_update():
def test_make_update_vars():
program = make_other("A[i] = b * c * d", "")
eqn = Equation(program, None)
stmt = "a_ref += b * c * d"
stmt = "a_ref <<= b * c * d"
assert eqn.make_update().gen(depth=0) == stmt


def test_make_update_mult_terms():
program = make_other("A[i] = b * B[i] + c * C[i] + d * D[i]", "")
eqn = Equation(program, None)
stmt = "a_ref += b * b_val + c * c_val + d * d_val"
stmt = "a_ref <<= b * b_val + c * c_val + d * d_val"
assert eqn.make_update().gen(depth=0) == stmt


def test_make_update_take():
_, eqn = make_take()
stmt = "z_ref += b"
stmt = "z_ref <<= b"
assert eqn.make_update().gen(depth=0) == stmt
14 changes: 7 additions & 7 deletions tests/trans/test_hifiber.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_translate_no_loops():
"tests/integration/test_translate_no_loops.yaml")
hifiber = "A_ = Tensor(rank_ids=[], name=\"A\")\n" + \
"a_ref = A_.getRoot()\n" + \
"a_ref += b"
"a_ref <<= b"
assert str(HiFiber(einsum, mapping)) == hifiber


Expand All @@ -32,7 +32,7 @@ def test_translate_defaults():
"c_m = C_MN.getRoot()\n" + \
"for m, (z_n, (_, t1_n, c_n)) in z_m << (t1_m | c_m):\n" + \
" for n, (z_ref, (_, t1_val, c_val)) in z_n << (t1_n | c_n):\n" + \
" z_ref += t1_val + c_val"
" z_ref <<= t1_val + c_val"

assert str(HiFiber(einsum, mapping)) == hifiber

Expand Down Expand Up @@ -88,7 +88,7 @@ def test_translate_specified():
" for n1, (z_m0, (_, t1_m0, c_m0)) in z_n1 << (t1_n1 | c_n1):\n" + \
" for m0, (z_n0, (_, t1_n0, c_n0)) in z_m0 << (t1_m0 | c_m0):\n" + \
" for n0, (z_ref, (_, t1_val, c_val)) in z_n0 << (t1_n0 | c_n0):\n" + \
" z_ref += t1_val + c_val\n" + \
" z_ref <<= t1_val + c_val\n" + \
"tmp14 = Z_M2N2M1N1M0N0\n" + \
"tmp15 = tmp14.swizzleRanks(rank_ids=[\"N2\", \"N1\", \"N0\", \"M2\", \"M1\", \"M0\"])\n" + \
"tmp16 = tmp15.mergeRanks(depth=3, levels=2, coord_style=\"absolute\")\n" + \
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_translate_specified():
" for n1, (z_m0, (_, t1_m0, c_m0)) in z_n1 << (t1_n1 | c_n1):\n" + \
" for m0, (z_n0, (_, t1_n0, c_n0)) in z_m0 << (t1_m0 | c_m0):\n" + \
" for n0, (z_ref, (_, t1_val, c_val)) in z_n0 << (t1_n0 | c_n0):\n" + \
" z_ref += t1_val + c_val\n" + \
" z_ref <<= t1_val + c_val\n" + \
"tmp14 = Z_M2N2M1N1M0N0\n" + \
"tmp15 = tmp14.swizzleRanks(rank_ids=[\"N2\", \"N1\", \"N0\", \"M2\", \"M1\", \"M0\"])\n" + \
"tmp16 = tmp15.mergeRanks(depth=0, levels=2, coord_style=\"absolute\")\n" + \
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_translate_specified():
" for n1, (z_m0, (_, t1_m0, c_m0)) in z_n1 << (t1_n1 | c_n1):\n" + \
" for m0, (z_n0, (_, t1_n0, c_n0)) in z_m0 << (t1_m0 | c_m0):\n" + \
" for n0, (z_ref, (_, t1_val, c_val)) in z_n0 << (t1_n0 | c_n0):\n" + \
" z_ref += t1_val + c_val\n" + \
" z_ref <<= t1_val + c_val\n" + \
"tmp14 = Z_M2N2M1N1M0N0\n" + \
"tmp15 = tmp14.swizzleRanks(rank_ids=[\"N2\", \"N1\", \"N0\", \"M2\", \"M1\", \"M0\"])\n" + \
"tmp16 = tmp15.mergeRanks(depth=0, levels=2, coord_style=\"absolute\")\n" + \
Expand Down Expand Up @@ -253,7 +253,7 @@ def test_translate_specified():
" for n1, (z_m0, (_, t1_m0, c_m0)) in z_n1 << (t1_n1 | c_n1):\n" + \
" for m0, (z_n0, (_, t1_n0, c_n0)) in z_m0 << (t1_m0 | c_m0):\n" + \
" for n0, (z_ref, (_, t1_val, c_val)) in z_n0 << (t1_n0 | c_n0):\n" + \
" z_ref += t1_val + c_val\n" + \
" z_ref <<= t1_val + c_val\n" + \
"tmp14 = Z_M2N2M1N1M0N0\n" + \
"tmp15 = tmp14.swizzleRanks(rank_ids=[\"N2\", \"N1\", \"N0\", \"M2\", \"M1\", \"M0\"])\n" + \
"tmp16 = tmp15.mergeRanks(depth=3, levels=2, coord_style=\"absolute\")\n" + \
Expand Down Expand Up @@ -381,7 +381,7 @@ def test_hifiber_index_math_no_halo():
" else:\n" + \
" m0_end = M\n" + \
" for m0, (z_ref, a_val) in z_m0 << a_k0.project(trans_fn=lambda k0: 1 / 2 * k0, interval=(m0_start, m0_end)).prune(trans_fn=lambda i, c, p: c % 1 == 0):\n" + \
" z_ref += a_val\n" + \
" z_ref <<= a_val\n" + \
"tmp3 = Z_M2M1M0\n" + \
"tmp4 = tmp3.mergeRanks(depth=0, levels=2, coord_style=\"absolute\")\n" + \
"tmp4.setRankIds(rank_ids=[\"M\"])\n" + \
Expand Down

0 comments on commit 68c53f1

Please sign in to comment.