diff --git a/distributed/system.py b/distributed/system.py
index 95d8df1086c..057b9df6728 100644
--- a/distributed/system.py
+++ b/distributed/system.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import logging
 import sys
 
 import psutil
@@ -7,6 +8,9 @@
 __all__ = ("memory_limit", "MEMORY_LIMIT")
 
 
+logger = logging.getLogger(__name__)
+
+
 def memory_limit() -> int:
     """Get the memory limit (in bytes) for this system.
 
@@ -21,6 +25,7 @@ def memory_limit() -> int:
     # Check cgroups if available
     # Note: can't use LINUX and WINDOWS constants as they upset mypy
     if sys.platform == "linux":
+        path_used = None
         for path in [
             "/sys/fs/cgroup/memory/memory.limit_in_bytes",  # cgroups v1 hard limit
             "/sys/fs/cgroup/memory/memory.soft_limit_in_bytes",  # cgroups v1 soft limit
@@ -31,9 +36,16 @@ def memory_limit() -> int:
                 with open(path) as f:
                     cgroups_limit = int(f.read())
-                if cgroups_limit > 0:
-                    limit = min(limit, cgroups_limit)
+                # Track the path only when its value actually lowers the limit
+                if 0 < cgroups_limit < limit:
+                    path_used = path
+                    limit = cgroups_limit
             except Exception:
                 pass
+        if path_used:
+            logger.debug(
+                "Setting system memory limit based on cgroup value defined in %s",
+                path_used,
+            )
 
     # Check rlimit if available
     if sys.platform != "win32":
@@ -41,8 +53,11 @@ def memory_limit() -> int:
             import resource
 
             hard_limit = resource.getrlimit(resource.RLIMIT_RSS)[1]
-            if hard_limit > 0:
-                limit = min(limit, hard_limit)
+            if 0 < hard_limit < limit:
+                logger.debug(
+                    "Limiting system memory based on RLIMIT_RSS to %s", hard_limit
+                )
+                limit = hard_limit
         except (ImportError, OSError):
             pass
 
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index b6b04c9969e..9b6314a3d74 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -3721,3 +3721,24 @@ async def test_deprecation_of_renamed_worker_attributes(s, a, b):
     )
     with pytest.warns(DeprecationWarning, match=msg):
         assert a.outgoing_current_count == a.transfer_outgoing_count
+
+
+@gen_cluster(nthreads=[])
+async def test_worker_log_memory_limit_too_high(s):
+    """A memory_limit above the system limit is ignored and a warning is logged."""
+    with captured_logger("distributed.worker_memory") as caplog:
+        async with Worker(s.address, memory_limit="1PB"):
+            pass
+
+        # Each entry lists alternative spellings; any one of them must appear in
+        # the lower-cased log output. Single-item entries need the trailing comma
+        # to be tuples rather than plain strings.
+        expected_snippets = [
+            ("ignore", "ignoring"),
+            ("memory limit", "memory_limit"),
+            ("system",),
+            ("1pb",),
+        ]
+        log_text = caplog.getvalue().lower()
+        for snippets in expected_snippets:
+            assert any(snip in log_text for snip in snippets)
diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py
index a5516f32019..4a67d751c77 100644
--- a/distributed/worker_memory.py
+++ b/distributed/worker_memory.py
@@ -385,7 +385,7 @@ def parse_memory_limit(
 ) -> int | None:
     if memory_limit is None:
         return None
-
+    orig = memory_limit
     if memory_limit == "auto":
         memory_limit = int(system.MEMORY_LIMIT * min(1, nthreads / total_cores))
     with suppress(ValueError, TypeError):
@@ -401,7 +401,15 @@ def parse_memory_limit(
     assert isinstance(memory_limit, int)
     if memory_limit == 0:
         return None
-    return min(memory_limit, system.MEMORY_LIMIT)
+    if system.MEMORY_LIMIT < memory_limit:
+        logger.warning(
+            "Ignoring provided memory limit %s due to system memory limit of %s",
+            orig,
+            format_bytes(system.MEMORY_LIMIT),
+        )
+        return system.MEMORY_LIMIT
+    else:
+        return memory_limit
 
 
 def _parse_threshold(