Skip to content

Commit

Permalink
Add some visualization of access count (in addition to existing access order)
Browse files Browse the repository at this point in the history
  • Loading branch information
hunhoffe committed Oct 25, 2024
1 parent b08eebf commit 06716e1
Showing 1 changed file with 124 additions and 65 deletions.
189 changes: 124 additions & 65 deletions python/helpers/tensortiler/tensortiler2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,12 @@ def __init__(
offset: int,
sizes: list[int],
strides: list[int],
transfer_len: int | None = None,
repeats: bool = False,
):
self.tensor_height = tensor_height
self.tensor_width = tensor_width
self.offset = offset
self.sizes = sizes
self.strides = strides
self.repeats = repeats
self.transfer_len = transfer_len

@property
def dimensions(self) -> list[tuple[int, int]]:
Expand All @@ -34,7 +30,7 @@ def visualize(
file_path: str | None = None,
show_plot: bool = True,
) -> None:
TensorTiler2D.generate_access_graph(
TensorTiler2D.generate_access_graphs(
self.tensor_height,
self.tensor_width,
self.sizes,
Expand All @@ -53,7 +49,6 @@ def access_order(self) -> np.ndarray:
self.sizes,
self.strides,
offset=self.offset,
allow_repeat=self.repeats,
)

def __str__(self) -> str:
Expand All @@ -73,13 +68,9 @@ def __init__(
strides: list[int],
offset_fn: Callable[[int], int],
num_steps: int,
transfer_len: int | None = None,
repeats: bool = False,
):
self._num_steps = num_steps
self._current_step = 0
self._transfer_len = transfer_len
self._repeats = repeats

self._tensor_height = tensor_height
self._tensor_width = tensor_width
Expand All @@ -101,8 +92,6 @@ def __next__(self) -> TensorTile:
self._offset_fn(step),
self._sizes,
self._strides,
self._transfer_len,
self._repeats,
)


Expand Down Expand Up @@ -178,7 +167,7 @@ def __init__(
else:
"""
This is the case that *should* always represent a correct/valid
transformation (according to my modelling using visualization tools).
transformation (according to my testing using visualization tools).
It should work even with the special cases above.
Expand Down Expand Up @@ -309,20 +298,17 @@ def calc_offset(iter_num):
iter_strides,
offset_fn=calc_offset,
num_steps=steps,
transfer_len=chunk_width
* chunk_height
* self._tile_height
* self._tile_width,
)

def __str__(self) -> str:
return f"sizes={self._sizes}, strides={self._strides}"

@classmethod
def _generate_access_graph_from_tensor(
def _generate_access_graphs_from_tensor(
cls,
access_order_tensor: np.ndarray,
title: str = "Access Order",
access_count_tensor: np.ndarray | None,
title: str = "Access Order and Access Count",
show_arrows: bool = True,
show_numbers: bool = False,
file_path: str | None = None,
Expand All @@ -337,41 +323,21 @@ def _generate_access_graph_from_tensor(
"You must pip install matplotlib in order to render access graphs"
)

# In inches, this is a little hacky
# should maybe be defined by the size of the tensor e.g., how many elem per inch
matplotlib.rcParams["figure.figsize"] = [10, 7]

_fig, ax = plt.subplots()
_heatmap = ax.pcolor(access_order_tensor, cmap="gnuplot2")
fig, (ax_order, ax_count) = plt.subplots(1, 2)
_access_heatmap = ax_order.pcolor(access_order_tensor, cmap="gnuplot2")

# Thanks to https://stackoverflow.com/questions/14406214/moving-x-axis-to-the-top-of-a-plot-in-matplotlib
# put the major ticks at the middle of each cell, (0, 0) in upper left corner
ax.set_xticks(np.arange(access_order_tensor.shape[1]) + 0.5, minor=False)
ax.set_yticks(np.arange(access_order_tensor.shape[0]) + 0.5, minor=False)
ax.invert_yaxis()
ax.xaxis.tick_top()
ax.set_xticklabels(
ax_order.set_xticks(np.arange(access_order_tensor.shape[1]) + 0.5, minor=False)
ax_order.set_yticks(np.arange(access_order_tensor.shape[0]) + 0.5, minor=False)
ax_order.invert_yaxis()
ax_order.xaxis.tick_top()
ax_order.set_xticklabels(
np.arange(0, access_order_tensor.shape[1]), minor=False, rotation="vertical"
)
ax.set_yticklabels(np.arange(0, access_order_tensor.shape[0]), minor=False)
plt.title(title)

# Add numbers to the plot
if show_numbers:
# Thanks to https://stackoverflow.com/questions/37719304/python-imshow-set-certain-value-to-defined-color
# Thanks to tmdavison answer here https://stackoverflow.com/a/40890587/7871710
for i in range(access_order_tensor.shape[0]):
for j in range(access_order_tensor.shape[1]):
c = access_order_tensor[i, j]
if c != -1:
ax.text(
j + 0.45,
i + 0.45,
str(c),
path_effects=[
pe.withStroke(linewidth=3, foreground="white")
],
)
ax_order.set_yticklabels(
np.arange(0, access_order_tensor.shape[0]), minor=False
)

# Add arrows to show access order
if show_arrows:
Expand All @@ -383,7 +349,7 @@ def _generate_access_graph_from_tensor(
for i in range(len(order_dict) - 1):
y1, x1 = order_dict[i]
y2, x2 = order_dict[i + 1]
ax.arrow(
ax_order.arrow(
x1 + 0.5,
y1 + 0.5,
x2 - x1,
Expand All @@ -394,7 +360,60 @@ def _generate_access_graph_from_tensor(
overhang=0.2,
path_effects=[pe.withStroke(linewidth=3, foreground="white")],
)
ax_order.set_title("Access Order")

if not (access_count_tensor is None):
max_count = np.max(access_count_tensor)

_count_heatmap = ax_count.pcolor(access_count_tensor, cmap="gnuplot2")
# Thanks to https://stackoverflow.com/questions/14406214/moving-x-axis-to-the-top-of-a-plot-in-matplotlib
# put the major ticks at the middle of each cell, (0, 0) in upper left corner
ax_count.set_xticks(
np.arange(access_count_tensor.shape[1]) + 0.5, minor=False
)
ax_count.set_yticks(
np.arange(access_count_tensor.shape[0]) + 0.5, minor=False
)
ax_count.invert_yaxis()
ax_count.xaxis.tick_top()
ax_count.set_xticklabels(
np.arange(0, access_count_tensor.shape[1]),
minor=False,
rotation="vertical",
)
ax_count.set_yticklabels(
np.arange(0, access_count_tensor.shape[0]), minor=False
)
ax_count.set_title("Access Counts")

# Add numbers to the plot
if show_numbers:
# Thanks to https://stackoverflow.com/questions/37719304/python-imshow-set-certain-value-to-defined-color
# Thanks to tmdavison answer here https://stackoverflow.com/a/40890587/7871710
for i in range(access_order_tensor.shape[0]):
for j in range(access_order_tensor.shape[1]):
c = access_order_tensor[i, j]
if c != -1:
ax_order.text(
j + 0.45,
i + 0.45,
str(c),
path_effects=[
pe.withStroke(linewidth=3, foreground="white")
],
)
if not (access_count_tensor is None):
c = access_count_tensor[i, j]
ax_count.text(
j + 0.45,
i + 0.45,
str(c),
path_effects=[
pe.withStroke(linewidth=3, foreground="white")
],
)

# plt.title(title)
if show_plot:
plt.show()
if file_path:
Expand All @@ -406,7 +425,7 @@ def _generate_access_graph_from_tensor(
plt.savefig(file_path)

@classmethod
def get_access_order_tensor(
def get_access_tensors(
cls,
tensor_height: int,
tensor_width: int,
Expand All @@ -416,7 +435,7 @@ def get_access_order_tensor(
tile_width: int | None = None,
offset: int = 0,
allow_repeat: bool = False,
) -> np.ndarray:
) -> (np.ndarray, np.ndarray | None):
assert tensor_height > 0 and tensor_width > 0, "Tensor dimensions must be > 0"
assert len(sizes) == 4, "Sizes should be a list of size 4"
assert len(strides) == 4, "Strides should be a list of size 4"
Expand All @@ -425,10 +444,19 @@ def get_access_order_tensor(
and (tile_height > 0 and tile_width > 0)
), "Tile Height and Tile Width should both be specified, or neither specified"

# Generate access order map
# Create access order map
access_order_tensor = np.full(
(tensor_height * tensor_width,), -1, dtype=cls.DTYPE
)

# Create access count map (if repeat allowed)
if allow_repeat:
access_count_tensor = np.full(
(tensor_height * tensor_width,), 0, dtype=cls.DTYPE
)
else:
access_count_tensor = None

access_count = 0
for i in range(sizes[0]):
for j in range(sizes[1]):
Expand All @@ -445,16 +473,24 @@ def get_access_order_tensor(
assert (
access_order_tensor[access_idx] == -1
), f"Attempted to access index={access_idx} twice."
else:
access_count_tensor[access_idx] += 1
access_order_tensor[access_idx] = access_count
access_count += 1
assert access_count <= np.prod(
access_order_tensor.shape
), f"Access pattern has too many elements (expected max {np.prod(access_order_tensor.shape)}, got {access_count})"
if not allow_repeat:
assert access_count <= np.prod(
access_order_tensor.shape
), f"Access pattern has too many elements (expected max {np.prod(access_order_tensor.shape)}, got {access_count})"

access_order_tensor = access_order_tensor.reshape((tensor_height, tensor_width))
return access_order_tensor
if allow_repeat:
access_count_tensor = access_count_tensor.reshape(
(tensor_height, tensor_width)
)
return access_order_tensor, access_count_tensor

@classmethod
def generate_access_graph(
def generate_access_graphs(
cls,
tensor_height: int,
tensor_width: int,
Expand All @@ -467,15 +503,17 @@ def generate_access_graph(
show_numbers: bool = False,
file_path: str | None = None,
show_plot: bool = True,
allow_repeat: bool = False,
):
access_order_tensor = cls.get_access_order_tensor(
access_order_tensor, access_count_tensor = cls.get_access_tensors(
tensor_height,
tensor_width,
sizes,
strides,
tile_height=tile_height,
tile_width=tile_width,
offset=offset,
allow_repeat=allow_repeat,
)

# Show a graph for a single tile
Expand All @@ -484,17 +522,23 @@ def generate_access_graph(
tile_file_path = file_path + ".tile.png"
else:
tile_file_path = None
cls._generate_access_graph_from_tensor(
cls._generate_access_graphs_from_tensor(
access_order_tensor[0:tile_height, 0:tile_width],
(
None
if not allow_repeat
else access_count_tensor[0:tile_height, 0:tile_width]
),
title="Per-Tile Access Order",
show_arrows=show_arrows,
show_numbers=show_numbers,
file_path=tile_file_path,
show_plot=show_plot,
)

cls._generate_access_graph_from_tensor(
cls._generate_access_graphs_from_tensor(
access_order_tensor,
access_count_tensor,
show_arrows=show_arrows,
show_numbers=show_numbers,
file_path=file_path,
Expand All @@ -508,10 +552,11 @@ def visualize(
show_numbers: bool = False,
file_path: str | None = None,
show_plot: bool = True,
allow_repeat: bool = False,
) -> None:
tile_height = self._tile_height if show_tile else None
tile_width = self._tile_width if show_tile else None
self.generate_access_graph(
self.generate_access_graphs(
self._tensor_height,
self._tensor_width,
self._sizes,
Expand All @@ -522,15 +567,29 @@ def visualize(
show_numbers=show_numbers,
file_path=file_path,
show_plot=show_plot,
allow_repeat=allow_repeat,
)

def access_order(self) -> np.ndarray:
def access_order(self, allow_repeat: bool = False) -> np.ndarray:
# Call class method
return self.get_access_order_tensor(
return self.get_access_tensors(
self._tensor_height,
self._tensor_width,
self._sizes,
self._strides,
self._tile_height,
self._tile_width,
)
allow_repeat=allow_repeat,
)[0]

def access_count(self) -> np.ndarray:
    """Return the per-element access-count tensor for this tiler.

    Delegates to the ``get_access_tensors`` class method using this
    instance's tensor dimensions, sizes, strides and tile dimensions.
    Repeats are always allowed here, because the count tensor is only
    produced when ``allow_repeat`` is true (it is ``None`` otherwise).

    Returns:
        The second element of the ``get_access_tensors`` result, i.e. the
        access-count tensor (the first element, the access-order tensor,
        is discarded).
    """
    # Only the counts are of interest; the access-order tensor is dropped.
    _order_tensor, count_tensor = self.get_access_tensors(
        self._tensor_height,
        self._tensor_width,
        self._sizes,
        self._strides,
        self._tile_height,
        self._tile_width,
        allow_repeat=True,
    )
    return count_tensor

0 comments on commit 06716e1

Please sign in to comment.