From 68c53f1bdde43bca8949dab1c97871def1696317 Mon Sep 17 00:00:00 2001 From: Nandeeka Nayak Date: Wed, 30 Oct 2024 13:52:18 -0500 Subject: [PATCH] Use <<= if there is no reduction --- teaal/trans/equation.py | 11 ++++++++++- tests/integration/example7.py | 4 ++-- tests/integration/gemm.py | 2 +- tests/integration/gemv.py | 2 +- tests/integration/nrm_sq.py | 2 +- tests/integration/outerprod.py | 2 +- tests/integration/sddmm.py | 2 +- tests/integration/test_integration.py | 5 ++++- tests/trans/test_equation.py | 6 +++--- tests/trans/test_hifiber.py | 14 +++++++------- 10 files changed, 31 insertions(+), 19 deletions(-) diff --git a/teaal/trans/equation.py b/teaal/trans/equation.py index 786f819..96a067d 100644 --- a/teaal/trans/equation.py +++ b/teaal/trans/equation.py @@ -263,7 +263,11 @@ def make_update(self) -> Statement: # Create the final statement out_name = self.program.get_equation().get_output().root_name().lower() + "_ref" - return SIAssign(AVar(out_name), OAdd(), sum_) + + if self.__no_reduction(): + return SIAssign(AVar(out_name), OLtLt(), sum_) + else: + return SIAssign(AVar(out_name), OAdd(), sum_) @staticmethod def __add_operator( @@ -288,6 +292,11 @@ def __add_enumerate(self, rank: str, expr: Expression) -> Expression: expr = EFunc("enumerate", [AJust(expr)]) return expr + def __no_reduction(self) -> bool: + out_ranks = self.program.get_equation().get_output().get_ranks() + all_ranks = self.program.get_loop_order().get_ranks() + return out_ranks == all_ranks + def __need_enumerate(self, rank: str) -> bool: """ Returns True if new need to enumerate over this rank diff --git a/tests/integration/example7.py b/tests/integration/example7.py index 1264b61..df97cd2 100644 --- a/tests/integration/example7.py +++ b/tests/integration/example7.py @@ -4,11 +4,11 @@ b_m = B_MN.getRoot() for m, (z_n, (_, a_n, b_n)) in z_m << (a_m | b_m): for n, (z_ref, (_, a_val, b_val)) in z_n << (a_n | b_n): - z_ref += a_val + b_val + z_ref <<= a_val + b_val Z_MN = Tensor(rank_ids=["M", "N"], name="Z") z_m = Z_MN.getRoot() a_m = A_MN.getRoot() b_m = B_MN.getRoot() for m, (z_n, (a_n, b_n)) in z_m << (a_m & b_m): for n, (z_ref, (a_val, b_val)) in z_n << (a_n & b_n): - z_ref += a_val * b_val \ No newline at end of file + z_ref <<= a_val * b_val \ No newline at end of file diff --git a/tests/integration/gemm.py b/tests/integration/gemm.py index c4d1c43..1e4c00d 100644 --- a/tests/integration/gemm.py +++ b/tests/integration/gemm.py @@ -14,4 +14,4 @@ c_m = C_MN.getRoot() for m, (z_n, (_, t1_n, c_n)) in z_m << (t1_m | c_m): for n, (z_ref, (_, t1_val, c_val)) in z_n << (t1_n | c_n): - z_ref += a * t1_val + b * c_val \ No newline at end of file + z_ref <<= a * t1_val + b * c_val \ No newline at end of file diff --git a/tests/integration/gemv.py b/tests/integration/gemv.py index 287dcda..301ea33 100644 --- a/tests/integration/gemv.py +++ b/tests/integration/gemv.py @@ -11,4 +11,4 @@ t1_m = T1_M.getRoot() c_m = C_M.getRoot() for m, (z_ref, (_, t1_val, c_val)) in z_m << (t1_m | c_m): - z_ref += a * t1_val + b * c_val \ No newline at end of file + z_ref <<= a * t1_val + b * c_val \ No newline at end of file diff --git a/tests/integration/nrm_sq.py b/tests/integration/nrm_sq.py index e0d9ded..c1c974f 100644 --- a/tests/integration/nrm_sq.py +++ b/tests/integration/nrm_sq.py @@ -5,7 +5,7 @@ for b, (t_i, v_i) in t_b << v_b: for i, (t_j, v_j) in t_i << v_i: for j, (t_ref, v_val) in t_j << v_j: - t_ref += v_val + t_ref <<= v_val Q_ = Tensor(rank_ids=[], name="Q") q_ref = Q_.getRoot() v_a = V_ABIJ.getRoot() diff --git a/tests/integration/outerprod.py b/tests/integration/outerprod.py index 4c6d6ba..61452da 100644 --- a/tests/integration/outerprod.py +++ b/tests/integration/outerprod.py @@ -4,4 +4,4 @@ b_n = B_N.getRoot() for m, (z_n, a_val) in z_m << a_m: for n, (z_ref, b_val) in z_n << b_n: - z_ref += a_val * b_val \ No newline at end of file + z_ref <<= a_val * b_val \ No newline at end of file diff --git a/tests/integration/sddmm.py b/tests/integration/sddmm.py index 3a656a0..18daf08 100644 --- a/tests/integration/sddmm.py +++ b/tests/integration/sddmm.py @@ -13,4 +13,4 @@ t1_m = T1_MN.getRoot() for m, (z_n, (c_n, t1_n)) in z_m << (c_m & t1_m): for n, (z_ref, (c_val, t1_val)) in z_n << (c_n & t1_n): - z_ref += c_val * t1_val \ No newline at end of file + z_ref <<= c_val * t1_val \ No newline at end of file diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 6059872..fdc8510 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -33,6 +33,7 @@ def read_hifiber(filename): def test_integration(): + errors = [] for test_name in test_names: filename = 'tests/integration/' + test_name @@ -43,4 +44,6 @@ def test_integration(): hifiber = read_hifiber(filename + ".py") if output != hifiber: print(output) - assert output == hifiber, test_name + " integration test failed!" \ No newline at end of file + errors.append(test_name) + + assert not errors, "Integration tests " + str(errors) + " failed!" \ No newline at end of file diff --git a/tests/trans/test_equation.py b/tests/trans/test_equation.py index d6c1774..fa1d9e1 100644 --- a/tests/trans/test_equation.py +++ b/tests/trans/test_equation.py @@ -635,18 +635,18 @@ def test_make_update(): def test_make_update_vars(): program = make_other("A[i] = b * c * d", "") eqn = Equation(program, None) - stmt = "a_ref += b * c * d" + stmt = "a_ref <<= b * c * d" assert eqn.make_update().gen(depth=0) == stmt def test_make_update_mult_terms(): program = make_other("A[i] = b * B[i] + c * C[i] + d * D[i]", "") eqn = Equation(program, None) - stmt = "a_ref += b * b_val + c * c_val + d * d_val" + stmt = "a_ref <<= b * b_val + c * c_val + d * d_val" assert eqn.make_update().gen(depth=0) == stmt def test_make_update_take(): _, eqn = make_take() - stmt = "z_ref += b" + stmt = "z_ref <<= b" assert eqn.make_update().gen(depth=0) == stmt diff --git a/tests/trans/test_hifiber.py b/tests/trans/test_hifiber.py index 7f8e449..dd69c48 100644 --- a/tests/trans/test_hifiber.py +++ b/tests/trans/test_hifiber.py @@ -8,7 +8,7 @@ def test_translate_no_loops(): "tests/integration/test_translate_no_loops.yaml") hifiber = "A_ = Tensor(rank_ids=[], name=\"A\")\n" + \ "a_ref = A_.getRoot()\n" + \ - "a_ref += b" + "a_ref <<= b" assert str(HiFiber(einsum, mapping)) == hifiber @@ -32,7 +32,7 @@ def test_translate_defaults(): "c_m = C_MN.getRoot()\n" + \ "for m, (z_n, (_, t1_n, c_n)) in z_m << (t1_m | c_m):\n" + \ " for n, (z_ref, (_, t1_val, c_val)) in z_n << (t1_n | c_n):\n" + \ - " z_ref += t1_val + c_val" + " z_ref <<= t1_val + c_val" assert str(HiFiber(einsum, mapping)) == hifiber @@ -88,7 +88,7 @@ def test_translate_specified(): " for n1, (z_m0, (_, t1_m0, c_m0)) in z_n1 << (t1_n1 | c_n1):\n" + \ " for m0, (z_n0, (_, t1_n0, c_n0)) in z_m0 << (t1_m0 | c_m0):\n" + \ " for n0, (z_ref, (_, t1_val, c_val)) in z_n0 << (t1_n0 | c_n0):\n" + \ - " z_ref += t1_val + c_val\n" + \ + " z_ref <<= t1_val + c_val\n" + \ "tmp14 = Z_M2N2M1N1M0N0\n" + \ "tmp15 = tmp14.swizzleRanks(rank_ids=[\"N2\", \"N1\", \"N0\", \"M2\", \"M1\", \"M0\"])\n" + \ "tmp16 = tmp15.mergeRanks(depth=3, levels=2, coord_style=\"absolute\")\n" + \ @@ -143,7 +143,7 @@ def test_translate_specified(): " for n1, (z_m0, (_, t1_m0, c_m0)) in z_n1 << (t1_n1 | c_n1):\n" + \ " for m0, (z_n0, (_, t1_n0, c_n0)) in z_m0 << (t1_m0 | c_m0):\n" + \ " for n0, (z_ref, (_, t1_val, c_val)) in z_n0 << (t1_n0 | c_n0):\n" + \ - " z_ref += t1_val + c_val\n" + \ + " z_ref <<= t1_val + c_val\n" + \ "tmp14 = Z_M2N2M1N1M0N0\n" + \ "tmp15 = tmp14.swizzleRanks(rank_ids=[\"N2\", \"N1\", \"N0\", \"M2\", \"M1\", \"M0\"])\n" + \ "tmp16 = tmp15.mergeRanks(depth=0, levels=2, coord_style=\"absolute\")\n" + \ @@ -198,7 +198,7 @@ def test_translate_specified(): " for n1, (z_m0, (_, t1_m0, c_m0)) in z_n1 << (t1_n1 | c_n1):\n" + \ " for m0, (z_n0, (_, t1_n0, c_n0)) in z_m0 << (t1_m0 | c_m0):\n" + \ " for n0, (z_ref, (_, t1_val, c_val)) in z_n0 << (t1_n0 | c_n0):\n" + \ - " z_ref += t1_val + c_val\n" + \ + " z_ref <<= t1_val + c_val\n" + \ "tmp14 = Z_M2N2M1N1M0N0\n" + \ "tmp15 = tmp14.swizzleRanks(rank_ids=[\"N2\", \"N1\", \"N0\", \"M2\", \"M1\", \"M0\"])\n" + \ "tmp16 = tmp15.mergeRanks(depth=0, levels=2, coord_style=\"absolute\")\n" + \ @@ -253,7 +253,7 @@ def test_translate_specified(): " for n1, (z_m0, (_, t1_m0, c_m0)) in z_n1 << (t1_n1 | c_n1):\n" + \ " for m0, (z_n0, (_, t1_n0, c_n0)) in z_m0 << (t1_m0 | c_m0):\n" + \ " for n0, (z_ref, (_, t1_val, c_val)) in z_n0 << (t1_n0 | c_n0):\n" + \ - " z_ref += t1_val + c_val\n" + \ + " z_ref <<= t1_val + c_val\n" + \ "tmp14 = Z_M2N2M1N1M0N0\n" + \ "tmp15 = tmp14.swizzleRanks(rank_ids=[\"N2\", \"N1\", \"N0\", \"M2\", \"M1\", \"M0\"])\n" + \ "tmp16 = tmp15.mergeRanks(depth=3, levels=2, coord_style=\"absolute\")\n" + \ @@ -381,7 +381,7 @@ def test_hifiber_index_math_no_halo(): " else:\n" + \ " m0_end = M\n" + \ " for m0, (z_ref, a_val) in z_m0 << a_k0.project(trans_fn=lambda k0: 1 / 2 * k0, interval=(m0_start, m0_end)).prune(trans_fn=lambda i, c, p: c % 1 == 0):\n" + \ - " z_ref += a_val\n" + \ + " z_ref <<= a_val\n" + \ "tmp3 = Z_M2M1M0\n" + \ "tmp4 = tmp3.mergeRanks(depth=0, levels=2, coord_style=\"absolute\")\n" + \ "tmp4.setRankIds(rank_ids=[\"M\"])\n" + \