Skip to content

Commit

Permalink
Remove _get_backend_opts from StencilConfig (#354)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdahm authored Oct 13, 2022
1 parent 4c5b18d commit 882db88
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 37 deletions.
1 change: 0 additions & 1 deletion dsl/pace/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,6 @@ def __init__(
gt_cache=gt4py.config.cache_settings["dir_name"]
),
)
self._argument_names = tuple(inspect.getfullargspec(func).args)

assert (
len(self._argument_names) > 0
Expand Down
41 changes: 6 additions & 35 deletions dsl/pace/dsl/stencil_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import dataclasses
import enum
import hashlib
import re
from typing import Any, Callable, Dict, Hashable, Iterable, Optional, Sequence, Tuple

from gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline
Expand Down Expand Up @@ -172,11 +171,13 @@ class StencilConfig(Hashable):
dace_config: Optional[DaceConfig] = None

def __post_init__(self):
self.backend_opts = self._get_backend_opts(
self.compilation_config.device_sync, self.compilation_config.format_source
)
self.backend_opts = {
"device_sync": self.compilation_config.device_sync,
"format_source": self.compilation_config.format_source,
}
self._hash = self._compute_hash()
# We need a DaceConfig to known our orchestration as part of the build system

# We need a DaceConfig to know if orchestration is part of the build system
# but we can't hash it very well (for now). The workaround is to make
# sure we have a default Python orchestrated config.
if self.dace_config is None:
Expand Down Expand Up @@ -216,36 +217,6 @@ def __eq__(self, other):
except AttributeError:
return False

def _get_backend_opts(
self,
device_sync: Optional[bool] = None,
format_source: Optional[bool] = None,
) -> Dict[str, Any]:
backend_opts: Dict[str, Any] = {}
all_backend_opts: Optional[Dict[str, Any]] = {
"device_sync": {
"backend": r".*(gpu|cuda)$",
"value": False,
},
"format_source": {
"value": False,
},
"verbose": {"backend": r"(gt:|cuda)", "value": False},
}
for name, option in all_backend_opts.items():
using_option_backend = re.match(
option.get("backend", ""), self.compilation_config.backend
)
if "backend" not in option or using_option_backend:
backend_opts[name] = option["value"]

if device_sync is not None:
backend_opts["device_sync"] = device_sync
if format_source is not None:
backend_opts["format_source"] = format_source

return backend_opts

def stencil_kwargs(
self, *, func: Callable[..., None], skip_passes: Iterable[str] = ()
):
Expand Down
1 change: 0 additions & 1 deletion tests/main/dsl/test_stencil_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,6 @@ def test_backend_options(
"device_sync": False,
"format_source": False,
"name": "test_stencil_wrapper.copy_stencil",
"verbose": False,
},
}

Expand Down

0 comments on commit 882db88

Please sign in to comment.