Skip to content

Commit

Permalink
Copy over the rank-zero utilities (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Sep 6, 2022
1 parent d821310 commit 6db8035
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/lightning_utilities/__about__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time

__version__ = "0.2.0"
__version__ = "0.3.0dev"
__author__ = "Lightning AI et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
Expand Down
89 changes: 89 additions & 0 deletions src/lightning_utilities/core/rank_zero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Utilities that can be used for calling functions on a particular rank."""
import logging
import warnings
from functools import wraps
from platform import python_version
from typing import Any, Callable, Optional, Union

log = logging.getLogger(__name__)


def rank_zero_only(fn: Callable) -> Callable:
"""Function that can be used as a decorator to enable a function/method being called only on global rank 0."""

@wraps(fn)
def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
rank = getattr(rank_zero_only, "rank", None)
if rank is None:
raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
if rank == 0:
return fn(*args, **kwargs)
return None

return wrapped_fn


def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
if python_version() >= "3.8.0":
kwargs["stacklevel"] = stacklevel
log.debug(*args, **kwargs)


@rank_zero_only
def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
"""Function used to log debug-level messages only on global rank 0."""
_debug(*args, stacklevel=stacklevel, **kwargs)


def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
if python_version() >= "3.8.0":
kwargs["stacklevel"] = stacklevel
log.info(*args, **kwargs)


@rank_zero_only
def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
"""Function used to log info-level messages only on global rank 0."""
_info(*args, stacklevel=stacklevel, **kwargs)


def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None:
warnings.warn(message, stacklevel=stacklevel, **kwargs)


@rank_zero_only
def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None:
"""Function used to log warn-level messages only on global rank 0."""
_warn(message, stacklevel=stacklevel, **kwargs)


rank_zero_deprecation_category = DeprecationWarning


def rank_zero_deprecation(message: Union[str, Warning], stacklevel: int = 5, **kwargs: Any) -> None:
category = kwargs.pop("category", rank_zero_deprecation_category)
rank_zero_warn(message, stacklevel=stacklevel, category=category, **kwargs)


def rank_prefixed_message(message: str, rank: Optional[int]) -> str:
if rank is not None:
# specify the rank of the process being logged
return f"[rank: {rank}] {message}"
return message


class WarningCache(set):
def warn(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
if message not in self:
self.add(message)
rank_zero_warn(message, stacklevel=stacklevel, **kwargs)

def deprecation(self, message: str, stacklevel: int = 6, **kwargs: Any) -> None:
if message not in self:
self.add(message)
rank_zero_deprecation(message, stacklevel=stacklevel, **kwargs)

def info(self, message: str, stacklevel: int = 5, **kwargs: Any) -> None:
if message not in self:
self.add(message)
rank_zero_info(message, stacklevel=stacklevel, **kwargs)
18 changes: 18 additions & 0 deletions tests/unittests/core/test_rank_zero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest

from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only


def test_rank_zero_only_raises():
foo = rank_zero_only(lambda x: x + 1)
with pytest.raises(RuntimeError, match="rank_zero_only.rank` needs to be set "):
foo(1)


@pytest.mark.parametrize("rank", [0, 1, 4])
def test_rank_prefixed_message(rank):
rank_zero_only.rank = rank
message = rank_prefixed_message("bar", rank)
assert message == f"[rank: {rank}] bar"
# reset
del rank_zero_only.rank

0 comments on commit 6db8035

Please sign in to comment.