From da14715cca2d2174d6f5554444e84f1a5d92cdf9 Mon Sep 17 00:00:00 2001 From: ksagiyam <46749170+ksagiyam@users.noreply.github.com> Date: Wed, 27 Sep 2023 16:44:55 +0100 Subject: [PATCH] fix dat version (#709) --- pyop2/parloop.py | 12 ++++++++++++ pyop2/types/dat.py | 4 ++++ test/unit/test_dats.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/pyop2/parloop.py b/pyop2/parloop.py index ef62a1878..48e73ecd1 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -17,6 +17,7 @@ from pyop2.local_kernel import LocalKernel, CStringLocalKernel, LoopyLocalKernel from pyop2.types import (Access, Global, AbstractDat, Dat, DatView, MixedDat, Mat, Set, MixedSet, ExtrudedSet, Subset, Map, ComposedMap, MixedMap) +from pyop2.types.data_carrier import DataCarrier from pyop2.utils import cached_property @@ -209,6 +210,7 @@ def compute(self): @mpi.collective def __call__(self): """Execute the kernel over all members of the iteration space.""" + self.increment_dat_version() self.zero_global_increments() orig_lgmaps = self.replace_lgmaps() self.global_to_local_begin() @@ -223,6 +225,16 @@ def __call__(self): self.finalize_global_increments() self.local_to_global_end() + def increment_dat_version(self): + """Increment dat versions of :class:`DataCarrier`s in the arguments.""" + for lk_arg, gk_arg, pl_arg in self.zipped_arguments: + assert isinstance(pl_arg.data, DataCarrier) + if lk_arg.access is not Access.READ: + if pl_arg.data in self.reduced_globals: + self.reduced_globals[pl_arg.data].data.increment_dat_version() + else: + pl_arg.data.increment_dat_version() + def zero_global_increments(self): """Zero any global increments every time the loop is executed.""" for g in self.reduced_globals.keys(): diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 37ac4fd8b..5ed6702a9 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -828,6 +828,10 @@ def what(x): def dat_version(self): return sum(d.dat_version for d in self._dats) + def increment_dat_version(self): + for d in self: + d.increment_dat_version() + def __call__(self, access, path=None): from pyop2.parloop import MixedDatLegacyArg return MixedDatLegacyArg(self, path, access) diff --git a/test/unit/test_dats.py b/test/unit/test_dats.py index 0868fd5bf..d43b5a1e4 100644 --- a/test/unit/test_dats.py +++ b/test/unit/test_dats.py @@ -183,6 +183,17 @@ def test_dat_version(self, s, d1): assert d1.dat_version == 4 assert d2.dat_version == 2 + # ParLoop + d3 = op2.Dat(s ** 1, data=None, dtype=np.uint32) + assert d3.dat_version == 0 + k = op2.Kernel(""" +static void write(unsigned int* v) { + *v = 1; +} +""", "write") + op2.par_loop(k, s, d3(op2.WRITE)) + assert d3.dat_version == 1 + def test_mixed_dat_version(self, s, d1, mdat): """Check object versioning for MixedDat""" d2 = op2.Dat(s) @@ -216,6 +227,25 @@ def test_mixed_dat_version(self, s, d1, mdat): assert mdat.dat_version == 8 assert mdat2.dat_version == 5 + # ParLoop + d3 = op2.Dat(s ** 1, data=None, dtype=np.uint32) + d4 = op2.Dat(s ** 1, data=None, dtype=np.uint32) + d3d4 = op2.MixedDat([d3, d4]) + assert d3.dat_version == 0 + assert d4.dat_version == 0 + assert d3d4.dat_version == 0 + k = op2.Kernel(""" +static void write(unsigned int* v) { + v[0] = 1; + v[1] = 2; +} +""", "write") + m = op2.Map(s, op2.Set(nelems), 1, values=[0, 1, 2, 3, 4]) + op2.par_loop(k, s, d3d4(op2.WRITE, op2.MixedMap([m, m]))) + assert d3.dat_version == 1 + assert d4.dat_version == 1 + assert d3d4.dat_version == 2 + def test_accessing_data_with_halos_increments_dat_version(self, d1): assert d1.dat_version == 0 d1.data_ro_with_halos