Skip to content

Commit

Permalink
Optimizations: fix indexing of restrict views adding extra operations
Browse files: browse the repository at this point in the history
  • Loading branch information
NaderAlAwar authored Sep 4, 2024
1 parent 791a46b commit bed9349
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
40 changes: 27 additions & 13 deletions pykokkos/core/optimizations/restrict_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from pykokkos.core import cppast
from pykokkos.interface import Subview, View, ViewType, Trait
from pykokkos.interface import Layout, Subview, View, ViewType, Trait


def may_share_memory(a, b) -> bool:
Expand Down Expand Up @@ -59,7 +59,7 @@ def may_share_memory(a, b) -> bool:

return False

def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Set[str], str]:
def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Dict[str, ViewType], str]:
"""
Identify views that do not alias each other to apply the restrict
keyword to
Expand All @@ -69,7 +69,7 @@ def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Set[str], str]:
"""

# Map from base view id() to a list of views that alias that view
base_view_ids: Dict[int, Set[str]] = {}
base_view_ids: Dict[int, Dict[str, ViewType]] = {}
# Map from view name to the xp_array it has
xp_arrays: Dict[str, Any] = {}

Expand Down Expand Up @@ -98,11 +98,11 @@ def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Set[str], str]:

view_id: int = id(base_view)
if view_id in base_view_ids:
base_view_ids[view_id].add(view_name)
base_view_ids[view_id][view_name] = view
else:
base_view_ids[view_id] = {view_name}
base_view_ids[view_id] = {view_name: view}

restricted_views: Set[str] = set()
restricted_views: Dict[str, ViewType] = {}
for view_id, view_set in base_view_ids.items():
# if len(view_set) == 1:
restricted_views.update(view_set)
Expand All @@ -127,7 +127,9 @@ def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Set[str], str]:
aliasing_arrays.add(name)
aliasing_arrays.add(other_name)

restricted_views -= aliasing_arrays
for arr in aliasing_arrays:
restricted_views.pop(arr, None)

restricted_signature: str = hashlib.md5("".join(sorted(restricted_views)).encode()).hexdigest()

return restricted_views, restricted_signature
Expand Down Expand Up @@ -219,7 +221,7 @@ def define_restrict_function(functor: cppast.RecordDecl, operation: str, workuni
return method


def index_restrict_view(name: cppast.DeclRefExpr, indices: List[cppast.Expr]) -> cppast.ArraySubscriptExpr:
def index_restrict_view(name: cppast.DeclRefExpr, indices: List[cppast.Expr], view: ViewType) -> cppast.ArraySubscriptExpr:
"""
Get the indexing operation of a particular view with a given list
of indices
Expand All @@ -229,12 +231,24 @@ def index_restrict_view(name: cppast.DeclRefExpr, indices: List[cppast.Expr]) ->
"""

restrict_name: cppast.DeclRefExpr = get_restrict_ptr_name(name)
full_index: cppast.Expr

# Subviews could be strided
if isinstance(view, Subview) or view.rank() > 2:
full_index = cppast.BinaryOperator(indices[0], get_stride_name(name, 0), cppast.BinaryOperatorKind.Mul)
for i, index in enumerate(indices[1:]):
current_stride: cppast.DeclRefExpr = get_stride_name(name, i + 1) # Add one since we did zero before the loop
current_mul = cppast.BinaryOperator(index, current_stride, cppast.BinaryOperatorKind.Mul)
full_index = cppast.BinaryOperator(current_mul, full_index, cppast.BinaryOperatorKind.Add)

full_index: cppast.Expr = cppast.BinaryOperator(indices[0], get_stride_name(name, 0), cppast.BinaryOperatorKind.Mul)
for i, index in enumerate(indices[1:]):
current_stride: cppast.DeclRefExpr = get_stride_name(name, i + 1) # Add one since we did zero before the loop
current_mul = cppast.BinaryOperator(index, current_stride, cppast.BinaryOperatorKind.Mul)
full_index = cppast.BinaryOperator(current_mul, full_index, cppast.BinaryOperatorKind.Add)
else:
if view.rank() == 1:
full_index = indices[0]
elif view.rank() == 2:
if view.layout is Layout.LayoutRight:
full_index = cppast.BinaryOperator(cppast.BinaryOperator(indices[0], get_stride_name(name, 0), cppast.BinaryOperatorKind.Mul), indices[1], cppast.BinaryOperatorKind.Add)
elif view.layout is Layout.LayoutLeft:
full_index = cppast.BinaryOperator(cppast.BinaryOperator(indices[1], get_stride_name(name, 1), cppast.BinaryOperatorKind.Mul), indices[0], cppast.BinaryOperatorKind.Add)

return cppast.ArraySubscriptExpr(restrict_name, [full_index])

Expand Down
6 changes: 5 additions & 1 deletion pykokkos/core/visitors/pykokkos_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,11 @@ def visit_Subscript(self, node: ast.Subscript) -> Union[cppast.ArraySubscriptExp
unfused_name: str = r.group(1) if r else ref.declname

if "PK_RESTRICT" in os.environ and unfused_name in self.restrict_views or name in self.restrict_views:
subscript = index_restrict_view(ref, args)
if unfused_name in self.restrict_views:
v = self.restrict_views[unfused_name]
else:
v = self.restrict_views[name]
subscript = index_restrict_view(ref, args, v)
else:
subscript = cppast.CallExpr(ref, args)

Expand Down

0 comments on commit bed9349

Please sign in to comment.