Skip to content

Commit

Permalink
Merge pull request #22 from tskisner/sysv
Browse files Browse the repository at this point in the history
Switch to using the sysv_ipc module
  • Loading branch information
tskisner authored Jan 31, 2024
2 parents 14c111b + d41d5e6 commit 7dbb4ed
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 64 deletions.
2 changes: 1 addition & 1 deletion pshmem/locking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
##
# Copyright (c) 2017-2020, all rights reserved. Use of this source code
# Copyright (c) 2017-2024, all rights reserved. Use of this source code
# is governed by a BSD license that can be found in the top-level
# LICENSE file.
##
Expand Down
82 changes: 22 additions & 60 deletions pshmem/shmem.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
##
# Copyright (c) 2017-2020, all rights reserved. Use of this source code
# Copyright (c) 2017-2024, all rights reserved. Use of this source code
# is governed by a BSD license that can be found in the top-level
# LICENSE file.
##

import sys
import mmap
import uuid

import numpy as np
import posix_ipc
import sysv_ipc

from .utils import mpi_data_type
from .utils import mpi_data_type, random_shm_key


class MPIShared(object):
Expand Down Expand Up @@ -147,16 +145,19 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
# and a unique random ID.

self._name = None
self._shm_index = None
if self._rank == 0:
rng_str = uuid.uuid4().hex[:12]
self._name = f"MPIShared_{rng_str}"
# Get a random 64bit integer within the supported range of keys
self._shm_index = random_shm_key()
# Name, just used for printing
self._name = f"MPIShared_{self._shm_index}"
if self._comm is not None:
self._shm_index = self._comm.bcast(self._shm_index, root=0)
self._name = self._comm.bcast(self._name, root=0)

# Only allocate our buffers if the total number of elements is > 0

self._shmem = None
self._shmap = None
self._flat = None
self.data = None

Expand All @@ -176,9 +177,9 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
# First rank on each node creates the buffer
if self._noderank == 0:
try:
self._shmem = posix_ipc.SharedMemory(
self._name,
posix_ipc.O_CREX,
self._shmem = sysv_ipc.SharedMemory(
self._shm_index,
flags=sysv_ipc.IPC_CREX,
size=int(nbytes),
)
except Exception as e:
Expand All @@ -190,27 +191,6 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
msg += ": {}".format(e)
print(msg, flush=True)
raise
try:
# MMap the shared memory
self._shmap = mmap.mmap(
self._shmem.fd,
self._shmem.size,
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed MMap of {} bytes".format(nbytes)
msg += " ({} elements of {} bytes each)".format(
self._n, self._dsize
)
msg += ": {}".format(e)
print(msg, flush=True)
# Try to free the shared memory object
try:
self._shmem.close_fd()
self._shmem.unlink()
except Exception as eclose:
pass
raise

# Wait for that to be created
if self._nodecomm is not None:
Expand All @@ -219,11 +199,8 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
# Other ranks on the node attach
if self._noderank != 0:
try:
self._shmem = posix_ipc.SharedMemory(self._name)
# MMap the shared memory
self._shmap = mmap.mmap(
self._shmem.fd,
self._shmem.size,
self._shmem = sysv_ipc.SharedMemory(
self._shm_index, flags=0, size=0
)
except Exception as e:
msg = "Process {}: {}".format(self._rank, self._name)
Expand All @@ -239,22 +216,15 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
if self._nodecomm is not None:
self._nodecomm.barrier()

# Now that all processes have mmap'ed the shared memory we can
# close the shared memory handle
self._shmem.close_fd()

# Wait for all processes to close file handle
if self._nodecomm is not None:
self._nodecomm.barrier()

# One process requests the file to be deleted, but this will not
# actually happen until all processes release their mmap.
# Now the rank zero process will call remove() to mark the shared
# memory segment for removal. However, this will not actually
# be removed until all processes detach.
if self._noderank == 0:
try:
self._shmem.unlink()
except posix_ipc.ExistentialError:
self._shmem.remove()
except sysv_ipc.ExistentialError:
msg = "Process {}: {}".format(self._rank, self._name)
msg += " failed to unlink shared memory"
msg += " failed to remove shared memory"
msg += ": {}".format(e)
print(msg, flush=True)
raise
Expand All @@ -263,7 +233,7 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
self._flat = np.ndarray(
self._n,
dtype=self._dtype,
buffer=self._shmap,
buffer=self._shmem,
)
# Initialize to zero.
if self._noderank == 0:
Expand All @@ -272,8 +242,6 @@ def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
# Wrap
self.data = self._flat.reshape(self._shape)



def __del__(self):
    """Finalizer: delegate cleanup to close() when the object is destroyed."""
    self.close()

Expand Down Expand Up @@ -399,17 +367,11 @@ def close(self):
del self.data
if hasattr(self, "_flat"):
del self._flat
if hasattr(self, "_shmap"):
# Close the mmap'ed memory
if self._shmap is not None:
self._shmap.close()
del self._shmap
self._shmap = None
if hasattr(self, "_shmem"):
if self._shmem is not None:
self._shmem.detach()
del self._shmem
self._shmem = None

self._flat = None
self.data = None

Expand Down
15 changes: 14 additions & 1 deletion pshmem/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
##
# Copyright (c) 2017-2020, all rights reserved. Use of this source code
# Copyright (c) 2017-2024, all rights reserved. Use of this source code
# is governed by a BSD license that can be found in the top-level
# LICENSE file.
##
Expand Down Expand Up @@ -425,6 +425,19 @@ def test_zero(self):
except RuntimeError:
print("successful raise with no data during set()", flush=True)

# def test_hang(self):
# # Run this while monitoring memory usage (e.g. with htop) and then
# # do kill -9 on one of the processes to verify that the kernel
# # releases shared memory.
# dims = (200, 1000000)
# dt = np.float64
# shm = MPIShared(dims, dt, self.comm)
# import time
# time.sleep(60)
# shm.close()
# del shm
# return


class LockTest(unittest.TestCase):
def setUp(self):
Expand Down
22 changes: 21 additions & 1 deletion pshmem/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
##
# Copyright (c) 2017-2020, all rights reserved. Use of this source code
# Copyright (c) 2017-2024, all rights reserved. Use of this source code
# is governed by a BSD license that can be found in the top-level
# LICENSE file.
##

import random

import numpy as np
import sysv_ipc


def mpi_data_type(comm, dt):
Expand Down Expand Up @@ -42,3 +45,20 @@ def mpi_data_type(comm, dt):
raise
dsize = mpitype.Get_size()
return (dsize, mpitype)


def random_shm_key():
    """Get a random integer key in the range supported by shmget().

    The key is drawn from the inclusive range
    [sysv_ipc.KEY_MIN, sysv_ipc.KEY_MAX].  A private generator is used,
    seeded from the default source (operating system randomness when
    available, otherwise the system time).  The module-level `random`
    generator is deliberately left untouched, so applications which seed
    it for reproducibility are not disturbed by creating shared memory.

    Returns:
        (int):  The random integer.

    """
    # A fresh Random() instance seeds itself exactly as random.seed(a=None)
    # would, but without resetting the global generator shared by all callers.
    rng = random.Random()
    return rng.randint(sysv_ipc.KEY_MIN, sysv_ipc.KEY_MAX)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def readme():
scripts=None,
license="BSD",
python_requires=">=3.8.0",
install_requires=["numpy", "posix_ipc"],
install_requires=["numpy", "sysv_ipc"],
extras_require={"mpi": ["mpi4py>=3.0"]},
cmdclass=versioneer.get_cmdclass(),
classifiers=[
Expand Down

0 comments on commit 7dbb4ed

Please sign in to comment.