Skip to content

Commit

Permalink
Merge pull request #4 from JiacongSun/master
Browse files Browse the repository at this point in the history
A Small Update on SpatialMappingConversionStage and SearchUnusedMemoryStage
  • Loading branch information
JiacongSun authored Nov 13, 2023
2 parents c34cfb0 + 9449772 commit a953be5
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 15 deletions.
26 changes: 20 additions & 6 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 30 additions & 3 deletions zigzag/classes/stages/SearchUnusedMemoryStage.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,16 @@ def update_top_mem_level(self):
if (
const_operand in served_operands
): # identify the top weight mem level
# We need to check if the current mem serve all oa dims, otherwise we will not decrease
# the mem_update_weight.
# The reason is if the current mem not serve all oa dims, the mapping will impact the memory
# utilization, so solely comparing with total memory size will be incorrect.
mem_serve_all_oa_dims = self.check_if_mem_serve_all_oa_dims(
mem, self.accelerator
)
if (
curr_mem_level < self.mem_update_weight
): # mem_update_weight is bigger than the top weight mem level
) and mem_serve_all_oa_dims: # mem_update_weight is bigger than the top weight mem level
self.mem_update_weight = curr_mem_level
break
else: ## node (layer) that is not a branch starting node or a branch final node
Expand Down Expand Up @@ -402,9 +409,18 @@ def update_top_mem_level(self):
self.update_IO_mem_level(
curr_id, output_operand, curr_mem_level
) # update output mem level
# For weight, we need to check if the current mem serve all oa dims, otherwise we will not
# decrease the mem_update_weight.
# The reason is if the current mem not serve all oa dims, the mapping will impact the memory
# utilization, so solely comparing with total memory size will be incorrect.
mem_serve_all_oa_dims = self.check_if_mem_serve_all_oa_dims(
mem, self.accelerator
)
if (
curr_mem_level < self.mem_update_weight
) and mem_serve_weight: # update weight mem level
(curr_mem_level < self.mem_update_weight)
and mem_serve_all_oa_dims
and mem_serve_weight
): # update weight mem level
self.mem_update_weight = curr_mem_level
## [OPTIONAL CHECK] assert check if there is -1 value in mem_update_list
## [NOTE] Until here, if there is still -1 value in mem_update_list, it means the size of top mem level for IO is not big enough.
Expand All @@ -414,6 +430,17 @@ def update_top_mem_level(self):
list(operand_dict.values())[0] >= 0
), "SearchUnusedMemoryStage fisnishes abnormally, there are still layers with top mem levels not figured out."

def check_if_mem_serve_all_oa_dims(self, mem, accelerator):
    """Return True if the memory level serves every operational-array (oa) dimension.

    A memory that serves only a subset of the oa dimensions has a utilization
    that depends on the spatial mapping, so callers must not rank it purely by
    total memory size (see the weight-mem update logic above).

    :param mem: memory level node exposing ``served_dimensions``.
    :param accelerator: accelerator whose first core's operational array
        defines the full set of hardware (oa) dimensions.
        NOTE(review): assumes cores[0] is representative — confirm for
        multi-core accelerators.
    :return: bool, True when ``mem`` spans as many oa dims as the array has.
    """
    operational_array = accelerator.cores[0].operational_array
    # Serving "all" oa dims == the served-dimension count matches the
    # operational array's dimension count; return the comparison directly
    # instead of an if/else that returns True/False.
    return len(mem.served_dimensions) == len(operational_array.dimensions)

def update_mem_level_for_loading_data(self):
"""
[OPTIONAL FUNCTION] This is an optional function.
Expand Down
41 changes: 35 additions & 6 deletions zigzag/classes/stages/SpatialMappingConversionStage.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,21 @@ def convert_user_spatial_mapping(self, user_spatial_mapping):
layer=self.layer,
accelerator=self.accelerator,
)
try:
SpatialMapping(spatial_mapping_dict=spatial_mapping_dict, layer_node=self.layer), SpatialMapping(
spatial_mapping_dict=spatial_mapping_dict_int, layer_node=self.layer
)
except:
pass

return SpatialMapping(
spatial_mapping_dict=spatial_mapping_dict, layer_node=self.layer
), SpatialMapping(
spatial_mapping_dict=spatial_mapping_dict_int, layer_node=self.layer
)
try:
return SpatialMapping(
spatial_mapping_dict=spatial_mapping_dict, layer_node=self.layer
), SpatialMapping(
spatial_mapping_dict=spatial_mapping_dict_int, layer_node=self.layer
)
except:
pass

def generate_limited_user_spatial_mapping(
self,
Expand Down Expand Up @@ -363,8 +372,28 @@ def generate_spatial_mapping_dict(self, user_spatial_mapping, layer, accelerator
# After we have gone through the memory levels, if there are still user-defined dimensions
# present, add them as the top level. Otherwise add an empty list to make arch levels correct:
# because first list we added was the operational array level.

# We will merge together if the top memory level is serving multiple oa dims
# and there are layer dims existing on multiple oa dims.
top_level_spatial_mapping_dict = {}
for (dim_name, spatial_loop) in user_sm_copy.items():
if self.is_nested_tuple(spatial_loop): # mix sm loop
for sub_spatial_loop in spatial_loop:
spatial_loop_dim = sub_spatial_loop[0]
spatial_loop_size = sub_spatial_loop[1]
if spatial_loop_dim not in top_level_spatial_mapping_dict.keys():
top_level_spatial_mapping_dict[spatial_loop_dim] = spatial_loop_size
else:
top_level_spatial_mapping_dict[spatial_loop_dim] *= spatial_loop_size
else:
spatial_loop_dim = spatial_loop[0]
spatial_loop_size = spatial_loop[1]
if spatial_loop_dim not in top_level_spatial_mapping_dict.keys():
top_level_spatial_mapping_dict[spatial_loop_dim] = spatial_loop_size
else:
top_level_spatial_mapping_dict[spatial_loop_dim] *= spatial_loop_size
top_level_spatial_mapping = [
spatial_loop for (dim_name, spatial_loop) in user_sm_copy.items()
(layer_dim, layer_size) for (layer_dim, layer_size) in top_level_spatial_mapping_dict.items()
]
spatial_mapping_dict[layer_op].append(top_level_spatial_mapping)
return spatial_mapping_dict
Expand Down

0 comments on commit a953be5

Please sign in to comment.