Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizations: fix indexing of restrict views adding extra operations #289

Merged
merged 1 commit into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions pykokkos/core/optimizations/restrict_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from pykokkos.core import cppast
from pykokkos.interface import Subview, View, ViewType, Trait
from pykokkos.interface import Layout, Subview, View, ViewType, Trait


def may_share_memory(a, b) -> bool:
Expand Down Expand Up @@ -59,7 +59,7 @@ def may_share_memory(a, b) -> bool:

return False

def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Set[str], str]:
def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Dict[str, ViewType], str]:
"""
Identify views that do not alias each other to apply the restrict
keyword to
Expand All @@ -69,7 +69,7 @@ def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Set[str], str]:
"""

# Map from base view id() to a list of views that alias that view
base_view_ids: Dict[int, Set[str]] = {}
base_view_ids: Dict[int, Dict[str, ViewType]] = {}
# Map from view name to the xp_array it has
xp_arrays: Dict[str, Any] = {}

Expand Down Expand Up @@ -98,11 +98,11 @@ def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Set[str], str]:

view_id: int = id(base_view)
if view_id in base_view_ids:
base_view_ids[view_id].add(view_name)
base_view_ids[view_id][view_name] = view
else:
base_view_ids[view_id] = {view_name}
base_view_ids[view_id] = {view_name: view}

restricted_views: Set[str] = set()
restricted_views: Dict[str, ViewType] = {}
for view_id, view_set in base_view_ids.items():
# if len(view_set) == 1:
restricted_views.update(view_set)
Expand All @@ -127,7 +127,9 @@ def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Set[str], str]:
aliasing_arrays.add(name)
aliasing_arrays.add(other_name)

restricted_views -= aliasing_arrays
for arr in aliasing_arrays:
restricted_views.pop(arr, None)

restricted_signature: str = hashlib.md5("".join(sorted(restricted_views)).encode()).hexdigest()

return restricted_views, restricted_signature
Expand Down Expand Up @@ -219,7 +221,7 @@ def define_restrict_function(functor: cppast.RecordDecl, operation: str, workuni
return method


def index_restrict_view(name: cppast.DeclRefExpr, indices: List[cppast.Expr]) -> cppast.ArraySubscriptExpr:
def index_restrict_view(name: cppast.DeclRefExpr, indices: List[cppast.Expr], view: ViewType) -> cppast.ArraySubscriptExpr:
"""
Get the indexing operation of a particular view with a given list
of indices
Expand All @@ -229,12 +231,24 @@ def index_restrict_view(name: cppast.DeclRefExpr, indices: List[cppast.Expr]) ->
"""

restrict_name: cppast.DeclRefExpr = get_restrict_ptr_name(name)
full_index: cppast.Expr

# Subviews could be strided
if isinstance(view, Subview) or view.rank() > 2:
full_index = cppast.BinaryOperator(indices[0], get_stride_name(name, 0), cppast.BinaryOperatorKind.Mul)
for i, index in enumerate(indices[1:]):
current_stride: cppast.DeclRefExpr = get_stride_name(name, i + 1) # Add one since we did zero before the loop
current_mul = cppast.BinaryOperator(index, current_stride, cppast.BinaryOperatorKind.Mul)
full_index = cppast.BinaryOperator(current_mul, full_index, cppast.BinaryOperatorKind.Add)

full_index: cppast.Expr = cppast.BinaryOperator(indices[0], get_stride_name(name, 0), cppast.BinaryOperatorKind.Mul)
for i, index in enumerate(indices[1:]):
current_stride: cppast.DeclRefExpr = get_stride_name(name, i + 1) # Add one since we did zero before the loop
current_mul = cppast.BinaryOperator(index, current_stride, cppast.BinaryOperatorKind.Mul)
full_index = cppast.BinaryOperator(current_mul, full_index, cppast.BinaryOperatorKind.Add)
else:
if view.rank() == 1:
full_index = indices[0]
elif view.rank() == 2:
if view.layout is Layout.LayoutRight:
full_index = cppast.BinaryOperator(cppast.BinaryOperator(indices[0], get_stride_name(name, 0), cppast.BinaryOperatorKind.Mul), indices[1], cppast.BinaryOperatorKind.Add)
elif view.layout is Layout.LayoutLeft:
full_index = cppast.BinaryOperator(cppast.BinaryOperator(indices[1], get_stride_name(name, 1), cppast.BinaryOperatorKind.Mul), indices[0], cppast.BinaryOperatorKind.Add)

return cppast.ArraySubscriptExpr(restrict_name, [full_index])

Expand Down
6 changes: 5 additions & 1 deletion pykokkos/core/visitors/pykokkos_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,11 @@ def visit_Subscript(self, node: ast.Subscript) -> Union[cppast.ArraySubscriptExp
unfused_name: str = r.group(1) if r else ref.declname

if "PK_RESTRICT" in os.environ and unfused_name in self.restrict_views or name in self.restrict_views:
subscript = index_restrict_view(ref, args)
if unfused_name in self.restrict_views:
v = self.restrict_views[unfused_name]
else:
v = self.restrict_views[name]
subscript = index_restrict_view(ref, args, v)
else:
subscript = cppast.CallExpr(ref, args)

Expand Down
Loading