diff --git a/atintegrators/TestRandomPass.c b/atintegrators/TestRandomPass.c index 030ed570f..858578468 100644 --- a/atintegrators/TestRandomPass.c +++ b/atintegrators/TestRandomPass.c @@ -7,6 +7,9 @@ #include "atelem.c" #include "atlalib.c" #include "atrandom.c" +#ifdef MPI +#include +#endif struct elem { @@ -19,22 +22,26 @@ static void RandomPass(double *r_in, int num_particles) { double common_val = atrandn_r(common_rng, 0.0, 0.001); +#ifdef MPI + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); +#else + int rank = 0; +#endif /* MPI */ for (int c = 0; c OMP_PARTICLE_THRESHOLD) default(none) \ shared(r_in, num_particles, common_val, thread_rng) for (int c = 0; chasSpare) { + rng->hasSpare = false; + return mean + stdDev * rng->spare; } - hasSpare = true; + rng->hasSpare = true; do { u = 2.0 * atrandd_r(rng) - 1.0; v = 2.0 * atrandd_r(rng) - 1.0; @@ -116,7 +116,7 @@ static double atrandn_r(pcg32_random_t* rng, double mean, double stdDev) } while ((s >= 1.0) || (s == 0.0)); s = sqrt(-2.0 * log(s) / s); - spare = v * s; + rng->spare = v * s; return mean + stdDev * u * s; } diff --git a/pyat/at.c b/pyat/at.c index 32dd0fd67..50ca5867e 100644 --- a/pyat/at.c +++ b/pyat/at.c @@ -9,6 +9,9 @@ #include #include #endif /*_OPENMP*/ +#ifdef MPI +#include +#endif /* MPI */ #include "attypes.h" #include #include @@ -694,10 +697,16 @@ static PyObject *at_elempass(PyObject *self, PyObject *args, PyObject *kwargs) static PyObject *reset_rng(PyObject *self, PyObject *args, PyObject *kwargs) { static char *kwlist[] = {"rank", "seed", NULL}; - uint64_t rank = 0; uint64_t seed = AT_RNG_STATE; +#ifdef MPI + int irank; + MPI_Comm_rank(MPI_COMM_WORLD, &irank); + uint64_t rank = irank; +#else + uint64_t rank = 0; +#endif /* MPI */ - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|K$K", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|$KK", kwlist, &rank, &seed)) { return NULL; } @@ -756,11 +765,13 @@ static PyMethodDef AtMethods[] = { ":meta private:" )}, {"reset_rng", (PyCFunction)reset_rng, METH_VARARGS | METH_KEYWORDS, - PyDoc_STR("reset_rng(rank=0, seed=None)\n\n" + PyDoc_STR("reset_rng(*, rank=0, seed=None)\n\n" "Reset the *common* and *thread* random generators.\n\n" + "The seed is applied unchanged to the \"common\" generator, and modified in a\n" + "thread-specific way to the \"thread\" generator\n\n" "Parameters:\n" " rank (int): thread identifier (for MPI and python multiprocessing)\n" - " seed (int): single seed for both generators\n" + " seed (int): single seed for both generators. Default: initial seed\n" )}, {"common_rng", (PyCFunction)common_rng, METH_NOARGS, PyDoc_STR("common_rng()\n\n" diff --git a/pyat/at/tracking/__init__.py b/pyat/at/tracking/__init__.py index d4efde3eb..9d3497556 100644 --- a/pyat/at/tracking/__init__.py +++ b/pyat/at/tracking/__init__.py @@ -1,11 +1,10 @@ """ Tracking functions """ -from ..lattice import DConstant from .atpass import reset_rng, common_rng, thread_rng from .track import * from .particles import * from .utils import * from .deprecated import * # initialise the C random generators -reset_rng(DConstant.rank) +reset_rng() diff --git a/pyat/at/tracking/atpass.pyi b/pyat/at/tracking/atpass.pyi index d321c39b4..94b9e9ea0 100644 --- a/pyat/at/tracking/atpass.pyi +++ b/pyat/at/tracking/atpass.pyi @@ -20,6 +20,6 @@ def elempass(element: Element, r_in, particle: Optional[Particle] = None, ): ... -def reset_rng(rank: int = 0, seed: Optional[int] = None) -> None: ... +def reset_rng(*, rank: int = 0, seed: Optional[int] = None) -> None: ... def common_rng() -> float: ... def thread_rng() -> float: ... diff --git a/pyat/at/tracking/track.py b/pyat/at/tracking/track.py index 4b6a70f2e..b6b83654e 100644 --- a/pyat/at/tracking/track.py +++ b/pyat/at/tracking/track.py @@ -23,14 +23,14 @@ def _atpass_fork(seed, rank, rin, **kwargs): """Single forked job""" - reset_rng(rank, seed=seed) + reset_rng(rank=rank, seed=seed) result = _atpass(_globring, rin, **kwargs) return rin, result def _atpass_spawn(ring, seed, rank, rin, **kwargs): """Single spawned job""" - reset_rng(rank, seed=seed) + reset_rng(rank=rank, seed=seed) result = _atpass(ring, rin, **kwargs) return rin, result