Skip to content

Commit

Permalink
docs: improved docstring on methods
Browse files Browse the repository at this point in the history
added more details on docstring of methods
  • Loading branch information
Rares Gaia committed Aug 5, 2024
1 parent b883db8 commit af46f30
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 10 deletions.
64 changes: 56 additions & 8 deletions multiworld/world_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,13 @@ async def _wait_work(self, work: Work, world_name: str) -> None:
async def send(
self, tensor: Tensor, dst: int, world_name: str = DEFAULT_WORLD_NAME
) -> None:
"""Send a tensor to a destination in a world."""
"""
Send a tensor to a destination in a world.
Args:
tensor (Tensor): Tensor to be sent.
dst (int): Destination rank from the world.
world_name (str): Name of the world.
"""
try:
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
Expand All @@ -122,7 +128,13 @@ async def send(
async def recv(
self, tensor: Tensor, src: int, world_name: str = DEFAULT_WORLD_NAME
) -> None:
"""Receive a tensor from a specific rank in a world."""
"""
Receive a tensor from a specific rank in a world.
Args:
tensor (Tensor): Tensor to receive the incoming data.
src (int): Source rank.
world_name (str): Name of the world.
"""
try:
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
Expand All @@ -142,7 +154,13 @@ async def recv(
async def broadcast(
self, tensor: Tensor, src: int, world_name: str = DEFAULT_WORLD_NAME
) -> None:
"""Broadcast a tensor to the world from a source (src)."""
"""
Broadcast a tensor to the world from a source (src).
Args:
tensor (Tensor): Tensor to be broadcast (sent from src, received on other ranks).
src (int): Source of the broadcast.
world_name (str): Name of the world.
"""
try:
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
Expand All @@ -165,7 +183,13 @@ async def all_reduce(
op: dist.ReduceOp = dist.ReduceOp.SUM,
world_name: str = DEFAULT_WORLD_NAME,
) -> None:
"""Do all-reduce for a given tensor in a world."""
"""
Do all-reduce for a given tensor in a world.
Args:
tensor (Tensor): Tensor to be reduced.
op (dist.ReduceOp): Reduce operation.
world_name (str): Name of the world.
"""
try:
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
Expand All @@ -191,7 +215,11 @@ async def reduce(
) -> None:
"""Do reduce for a given tensor in a world.
The rank is a receiver of the final result.
Args:
tensor (Tensor): Tensor to be reduced.
dst (int): Rank to receive the reduced tensor.
op (dist.ReduceOp): Reduce operation.
world_name (str): Name of the world.
"""
try:
with concurrent.futures.ThreadPoolExecutor() as pool:
Expand All @@ -216,7 +244,13 @@ async def all_gather(
tensor: Tensor,
world_name: str = DEFAULT_WORLD_NAME,
) -> None:
"""Do all-gather for a given tensor in a world."""
"""
Do all-gather for a given tensor in a world.
Args:
tensors (list[Tensor]): List of tensors.
tensor (Tensor): Tensor to be gathered.
world_name (str): Name of the world.
"""
try:
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
Expand All @@ -240,7 +274,14 @@ async def gather(
dst: int = 0,
world_name: str = DEFAULT_WORLD_NAME,
) -> None:
"""Do gather for a list of tensors in a world."""
"""
Do gather for a list of tensors in a world.
Args:
tensor (Tensor): Tensor to be gathered.
gather_list (list[Tensor]): List of tensors to hold the gathered results.
dst (int): Rank to receive the gathered tensors.
world_name (str): Name of the world.
"""
try:
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
Expand All @@ -265,7 +306,14 @@ async def scatter(
src: int = 0,
world_name: str = DEFAULT_WORLD_NAME,
) -> None:
"""Do scatter for a list of tensors from a source (src) in a world."""
"""
Do scatter for a list of tensors from a source (src) in a world.
Args:
tensor (Tensor): Tensor to receive the scattered data.
scatter_list (list[Tensor]): List of tensors to scatter.
src (int): Rank that scatters tensors.
world_name (str): Name of the world.
"""
try:
with concurrent.futures.ThreadPoolExecutor() as pool:
work = await self._loop.run_in_executor(
Expand Down
16 changes: 14 additions & 2 deletions multiworld/world_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def __init__(self, enable_monitor=True):
_ = asyncio.create_task(self._cleanup_worlds())

def cleanup(self):
"""Call os._exit(0) explicitly."""
"""
Terminate the process explicitly by calling os._exit(0).
"""
# TODO: This is a temporary workaround to prevent main thread hang
# even after it's done. Calling os._exit(0) guarantees
termination of the process. We need to figure out why
Expand Down Expand Up @@ -118,7 +120,17 @@ async def initialize_world(
addr: str = "127.0.0.1",
port: int = -1,
):
"""Initialize world."""
"""
Initialize a world for the given rank using the world name, backend, port, and addr.
Args:
world_name (str): Name of the world.
rank (int): Rank of the process.
world_size (int): Size of the world.
backend (str): Backend used for communication.
addr (str): IP address
port (int): Port number
"""
self.add_world(world_name)

loop = asyncio.get_running_loop()
Expand Down

0 comments on commit af46f30

Please sign in to comment.