Skip to content

Commit

Permalink
Move intersect counts to an element of a dictionary (so that we can a…
Browse files Browse the repository at this point in the history
…lso record time)
  • Loading branch information
nandeeka committed Oct 26, 2023
1 parent fc7ceff commit ff8c108
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 27 deletions.
13 changes: 7 additions & 6 deletions teaal/trans/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,10 @@ def __build_intersections(self) -> Statement:
for intersector in self.metrics.get_hardware().get_components(einsum,
IntersectorComponent):
isect_name = intersector.get_name()
metrics_isect = AAccess(metrics_einsum, EString(isect_name))
block.add(SAssign(metrics_isect, EInt(0)))
block.add(SAssign(AAccess(metrics_einsum, EString(isect_name)), EDict({})))
metrics_isect = EAccess(metrics_einsum, EString(isect_name))
metrics_isect_op = AAccess(metrics_isect, EString("intersect"))
block.add(SAssign(metrics_isect_op, EInt(0)))

for binding in intersector.get_bindings()[einsum]:
isects = EMethod(
Expand All @@ -556,15 +558,14 @@ def __build_intersections(self) -> Statement:
binding["rank"]),
"getNumIntersects",
[])
block.add(SIAssign(metrics_isect, OAdd(), isects))
block.add(SIAssign(metrics_isect_op, OAdd(), isects))

# op_freq = cycles / s * ops / cycle
op_freq = self.metrics.get_hardware().get_frequency(einsum) * \
intersector.get_num_instances()
metrics_isect_expr = EAccess(metrics_einsum, EString(isect_name))
time = EBinOp(metrics_isect_expr, ODiv(), EInt(op_freq))
time = EBinOp(EAccess(metrics_isect, EString("intersect")), ODiv(), EInt(op_freq))

metrics_time = AAccess(metrics_isect_expr, EString("time"))
metrics_time = AAccess(metrics_isect, EString("time"))
block.add(SAssign(metrics_time, time))
self.fusion.add_component(einsum, intersector.get_name())

Expand Down
49 changes: 28 additions & 21 deletions tests/trans/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,10 @@ def test_dump_gamma_T():
"metrics[\"T\"][\"MainMemory\"][\"A\"][\"read\"] = 0\n" + \
"metrics[\"T\"][\"MainMemory\"][\"A\"][\"read\"] += traffic[0][\"A\"][\"read\"]\n" + \
"metrics[\"T\"][\"MainMemory\"][\"time\"] = (metrics[\"T\"][\"MainMemory\"][\"A\"][\"read\"] + metrics[\"T\"][\"MainMemory\"][\"B\"][\"read\"]) / 1099511627776\n" + \
"metrics[\"T\"][\"Intersect\"] = 0\n" + \
"metrics[\"T\"][\"Intersect\"] += Intersect_K.getNumIntersects()\n" + \
"metrics[\"T\"][\"Intersect\"][\"time\"] = metrics[\"T\"][\"Intersect\"] / 32000000000"
"metrics[\"T\"][\"Intersect\"] = {}\n" + \
"metrics[\"T\"][\"Intersect\"][\"intersect\"] = 0\n" + \
"metrics[\"T\"][\"Intersect\"][\"intersect\"] += Intersect_K.getNumIntersects()\n" + \
"metrics[\"T\"][\"Intersect\"][\"time\"] = metrics[\"T\"][\"Intersect\"][\"intersect\"] / 32000000000"

assert collector.dump().gen(0) == hifiber

Expand Down Expand Up @@ -426,15 +427,18 @@ def test_dump_extensor():
"metrics[\"Z\"][\"FPAdd\"] = {}\n" + \
"metrics[\"Z\"][\"FPAdd\"][\"add\"] = Metrics.dump()[\"Compute\"][\"payload_add\"]\n" + \
"metrics[\"Z\"][\"FPAdd\"][\"time\"] = metrics[\"Z\"][\"FPAdd\"][\"add\"] / 128000000000\n" + \
"metrics[\"Z\"][\"K2Intersect\"] = 0\n" + \
"metrics[\"Z\"][\"K2Intersect\"] += K2Intersect_K2.getNumIntersects()\n" + \
"metrics[\"Z\"][\"K2Intersect\"][\"time\"] = metrics[\"Z\"][\"K2Intersect\"] / 1000000000\n" + \
"metrics[\"Z\"][\"K1Intersect\"] = 0\n" + \
"metrics[\"Z\"][\"K1Intersect\"] += K1Intersect_K1.getNumIntersects()\n" + \
"metrics[\"Z\"][\"K1Intersect\"][\"time\"] = metrics[\"Z\"][\"K1Intersect\"] / 1000000000\n" + \
"metrics[\"Z\"][\"K0Intersection\"] = 0\n" + \
"metrics[\"Z\"][\"K0Intersection\"] += K0Intersection_K0.getNumIntersects()\n" + \
"metrics[\"Z\"][\"K0Intersection\"][\"time\"] = metrics[\"Z\"][\"K0Intersection\"] / 128000000000\n" + \
"metrics[\"Z\"][\"K2Intersect\"] = {}\n" + \
"metrics[\"Z\"][\"K2Intersect\"][\"intersect\"] = 0\n" + \
"metrics[\"Z\"][\"K2Intersect\"][\"intersect\"] += K2Intersect_K2.getNumIntersects()\n" + \
"metrics[\"Z\"][\"K2Intersect\"][\"time\"] = metrics[\"Z\"][\"K2Intersect\"][\"intersect\"] / 1000000000\n" + \
"metrics[\"Z\"][\"K1Intersect\"] = {}\n" + \
"metrics[\"Z\"][\"K1Intersect\"][\"intersect\"] = 0\n" + \
"metrics[\"Z\"][\"K1Intersect\"][\"intersect\"] += K1Intersect_K1.getNumIntersects()\n" + \
"metrics[\"Z\"][\"K1Intersect\"][\"time\"] = metrics[\"Z\"][\"K1Intersect\"][\"intersect\"] / 1000000000\n" + \
"metrics[\"Z\"][\"K0Intersection\"] = {}\n" + \
"metrics[\"Z\"][\"K0Intersection\"][\"intersect\"] = 0\n" + \
"metrics[\"Z\"][\"K0Intersection\"][\"intersect\"] += K0Intersection_K0.getNumIntersects()\n" + \
"metrics[\"Z\"][\"K0Intersection\"][\"time\"] = metrics[\"Z\"][\"K0Intersection\"][\"intersect\"] / 128000000000\n" + \
"metrics[\"blocks\"] = [[\"Z\"]]\n" + \
"metrics[\"time\"] = max(metrics[\"Z\"][\"FPAdd\"][\"time\"], metrics[\"Z\"][\"FPMul\"][\"time\"], metrics[\"Z\"][\"K0Intersection\"][\"time\"], metrics[\"Z\"][\"K1Intersect\"][\"time\"], metrics[\"Z\"][\"K2Intersect\"][\"time\"], metrics[\"Z\"][\"MainMemory\"][\"time\"])"

Expand Down Expand Up @@ -488,15 +492,18 @@ def test_dump_extensor_energy():
"metrics[\"Z\"][\"FPAdd\"] = {}\n" + \
"metrics[\"Z\"][\"FPAdd\"][\"add\"] = Metrics.dump()[\"Compute\"][\"payload_add\"]\n" + \
"metrics[\"Z\"][\"FPAdd\"][\"time\"] = metrics[\"Z\"][\"FPAdd\"][\"add\"] / 128000000000\n" + \
"metrics[\"Z\"][\"K2Intersect\"] = 0\n" + \
"metrics[\"Z\"][\"K2Intersect\"] += K2Intersect_K2.getNumIntersects()\n" + \
"metrics[\"Z\"][\"K2Intersect\"][\"time\"] = metrics[\"Z\"][\"K2Intersect\"] / 1000000000\n" + \
"metrics[\"Z\"][\"K1Intersect\"] = 0\n" + \
"metrics[\"Z\"][\"K1Intersect\"] += K1Intersect_K1.getNumIntersects()\n" + \
"metrics[\"Z\"][\"K1Intersect\"][\"time\"] = metrics[\"Z\"][\"K1Intersect\"] / 1000000000\n" + \
"metrics[\"Z\"][\"K0Intersection\"] = 0\n" + \
"metrics[\"Z\"][\"K0Intersection\"] += K0Intersection_K0.getNumIntersects()\n" + \
"metrics[\"Z\"][\"K0Intersection\"][\"time\"] = metrics[\"Z\"][\"K0Intersection\"] / 128000000000\n" + \
"metrics[\"Z\"][\"K2Intersect\"] = {}\n" + \
"metrics[\"Z\"][\"K2Intersect\"][\"intersect\"] = 0\n" + \
"metrics[\"Z\"][\"K2Intersect\"][\"intersect\"] += K2Intersect_K2.getNumIntersects()\n" + \
"metrics[\"Z\"][\"K2Intersect\"][\"time\"] = metrics[\"Z\"][\"K2Intersect\"][\"intersect\"] / 1000000000\n" + \
"metrics[\"Z\"][\"K1Intersect\"] = {}\n" + \
"metrics[\"Z\"][\"K1Intersect\"][\"intersect\"] = 0\n" + \
"metrics[\"Z\"][\"K1Intersect\"][\"intersect\"] += K1Intersect_K1.getNumIntersects()\n" + \
"metrics[\"Z\"][\"K1Intersect\"][\"time\"] = metrics[\"Z\"][\"K1Intersect\"][\"intersect\"] / 1000000000\n" + \
"metrics[\"Z\"][\"K0Intersection\"] = {}\n" + \
"metrics[\"Z\"][\"K0Intersection\"][\"intersect\"] = 0\n" + \
"metrics[\"Z\"][\"K0Intersection\"][\"intersect\"] += K0Intersection_K0.getNumIntersects()\n" + \
"metrics[\"Z\"][\"K0Intersection\"][\"time\"] = metrics[\"Z\"][\"K0Intersection\"][\"intersect\"] / 128000000000\n" + \
"metrics[\"Z\"][\"TopSequencer\"] = {}\n" + \
"metrics[\"Z\"][\"TopSequencer\"][\"N2\"] = Compute.numIters(\"tmp/extensor_energy-N2-iter.csv\")\n" + \
"metrics[\"Z\"][\"TopSequencer\"][\"K2\"] = Compute.numIters(\"tmp/extensor_energy-K2-iter.csv\")\n" + \
Expand Down

0 comments on commit ff8c108

Please sign in to comment.