diff --git a/pykokkos/core/fusion/access_modes.py b/pykokkos/core/fusion/access_modes.py index e6514ec..bea5321 100644 --- a/pykokkos/core/fusion/access_modes.py +++ b/pykokkos/core/fusion/access_modes.py @@ -114,9 +114,10 @@ def visit_Call(self, node: ast.Call) -> None: # Treat function calls like a black box for arg in node.args: if not isinstance(arg, ast.Name): - continue + self.visit(arg) - if arg.id in self.view_args: + # If an entire view is passed to a function + elif arg.id in self.view_args: rank: int = self.view_args[arg.id] for i in range(rank): self.access_indices[(arg.id, i)] = (AccessIndex.All, AccessMode.ReadWrite, "") diff --git a/pykokkos/core/fusion/trace.py b/pykokkos/core/fusion/trace.py index af22d9b..a312727 100644 --- a/pykokkos/core/fusion/trace.py +++ b/pykokkos/core/fusion/trace.py @@ -118,14 +118,7 @@ def log_operation( access_modes: Dict[str, AccessMode] dependencies, access_modes = self.get_data_dependencies(kwargs, AST, cache_key) - access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] - - if cache_key in self.safety_cache: - access_indices = self.safety_cache[cache_key] - else: - access_indices = self.get_safety_info(kwargs, AST) - self.safety_cache[cache_key] = access_indices - + access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = self.get_safety_info(kwargs, AST, cache_key) tracer_op = TracerOperation(self.op_id, future, name, policy, workunit, operation, parser, entity_name, dict(kwargs), dependencies, access_indices) self.op_id += 1 @@ -133,12 +126,13 @@ def log_operation( self.operations[tracer_op] = None - def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]: + def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef, cache_key: Tuple[str, str]) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]: """ Get the view access indices needed to check for safety :param kwargs: the keyword arguments passed to the workunit :param AST: the AST of the input workunit + :param cache_key: used to cache the safety info extracted from the AST :returns: the set of data dependencies and the access modes of the views """ @@ -154,7 +148,13 @@ def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Dict[ # Map from view name (str) + dimension (int) to the type of # access to that view's dimension - write_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = get_view_write_indices_and_modes(AST, view_name_and_rank) + write_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] + + if cache_key in self.safety_cache: + write_indices = self.safety_cache[cache_key] + else: + write_indices = get_view_write_indices_and_modes(AST, view_name_and_rank) + self.safety_cache[cache_key] = write_indices # Now need to convert view name to view ID safety_info: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = {}