From e880536c8ed4d88ccec7248569d48b6cfeda8a17 Mon Sep 17 00:00:00 2001 From: Rares Gaia Date: Wed, 7 Aug 2024 16:54:54 +0300 Subject: [PATCH] refactor: refactor resnet example updated resnet example to use multiworld API --- examples/resnet/m8d.py | 109 +++++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 48 deletions(-) diff --git a/examples/resnet/m8d.py b/examples/resnet/m8d.py index b9c63c8..49b1bfc 100644 --- a/examples/resnet/m8d.py +++ b/examples/resnet/m8d.py @@ -32,7 +32,6 @@ import argparse import asyncio import atexit -import os import time import torch @@ -64,6 +63,8 @@ LEADER_RANK = 0 # Worker rank in every world is going to be 1 because we are creating 2 processes in every world WORKER_RANK = 1 +STARTING_PORT = 29500 +WORLD_SIZE = 2 def index_to_class_name(index): @@ -71,7 +72,7 @@ def index_to_class_name(index): Get class name from index. Args: - index (int): Index of the class. + index: Index of the class. """ return CIFAR10_CLASS_NAMES[index] @@ -81,7 +82,10 @@ def load_cifar10(batch_size=1): Load CIFAR10 dataset. Args: - batch_size (int): Batch size for the DataLoader. + batch_size: Batch size for the DataLoader. + + Returns: + DataLoader instance """ transform = transforms.Compose( [ @@ -96,35 +100,37 @@ def load_cifar10(batch_size=1): return cifar10_loader -def dummy(world_name, rank, size, backend): +async def dummy(world_name, rank, size, backend, world_communicator): """ Dummy function to be implemented later. Args: - world_name (str): Name of the world. - rank (int): Rank of the process. - size (int): Number of processes. - backend (str): Backend used for communication. + world_name: Name of the world. + rank: Rank of the process. + size: Number of processes. + backend: Backend used for communication. """ print(f"dummy function: world: {world_name}, my rank: {rank}, world size: {size}") + await asyncio.sleep(0) -def run(world_name, rank, size, backend): +async def run(world_name, rank, size, backend, world_communicator): """ Distributed function to be implemented later. Args: - world_name (str): Name of the world. - rank (int): Rank of the process. - size (int): Number of processes. - backend (str): Backend used for communication. + world_name: Name of the world. + rank: Rank of the process. + size: Number of processes. + backend: Backend used for communication. + world_communicator: World communicator """ world_idx = int(world_name[5:]) # Initialize ResNet18 model model = AutoModelForImageClassification.from_pretrained( - "jialicheng/resnet-18-cifar10-21" + "edadaltocg/resnet18_cifar10" ) model.eval() @@ -136,8 +142,12 @@ def run(world_name, rank, size, backend): image_tensor = ( image_tensor.to(f"cuda:{world_idx}") if backend == "nccl" else image_tensor ) - - dist.recv(image_tensor, src=LEADER_RANK) + try: + await world_communicator.recv(image_tensor, LEADER_RANK, world_name) + except Exception as e: + print("RECV FAILED for RANK", rank) + print(f"Caught an except while receiving predicted class: {e}") + break # Inference with torch.no_grad(): @@ -149,13 +159,16 @@ def run(world_name, rank, size, backend): print(f"Predicted : {predicted}, {predicted.shape}, {type(predicted)}") # Send the predicted class back to the leader - dist.send(predicted, dst=LEADER_RANK) + try: + await world_communicator.send(predicted, LEADER_RANK, world_name) + except Exception as e: + print(f"Caught an except while sending image: {e}") + break print(f"Predicted class: {predicted}") world_manager = None -STARTING_PORT = 29500 async def init_world( @@ -165,13 +178,13 @@ async def init_world( Initialize the distributed environment. Args: - world_name (str): Name of the world. - rank (int): Rank of the process. - size (int): Number of processes. - fn (function): Function to be executed. - backend (str): Backend to be used. - addr (str): Address of the leader process. - port (int): Port to be used. + world_name: Name of the world. + rank: Rank of the process. + size: Number of processes. + fn: Function to be executed. + backend: Backend to be used. + addr: Address of the leader process. + port: Port to be used. """ global world_manager @@ -179,11 +192,11 @@ async def init_world( # TODO: make WorldManager as singleton world_manager = dist.WorldManager() - world_manager.initialize_world( + await world_manager.initialize_world( world_name, rank, size, backend=backend, addr=addr, port=port ) - fn(world_name, rank, size, backend) + await fn(world_name, rank, size, backend, world_manager.communicator) def run_init_world( @@ -193,13 +206,13 @@ def run_init_world( Run the init_world function in a separate process. Args: - world_name (str): Name of the world. - rank (int): Rank of the process. - size (int): Number of processes. + world_name: Name of the world. + rank: Rank of the process. + size: Number of processes. fn (function): Function to be executed. - backend (str): Backend to be used. - addr (str): Address of the leader process. - port (int): Port to be used. + backend: Backend to be used. + addr: Address of the leader process. + port: Port to be used. """ asyncio.run(init_world(world_name, rank, size, fn, backend, addr, port)) @@ -212,13 +225,13 @@ async def create_world(world_name, world_size, addr, port, backend, fn1, fn2): Create a world with the given port and world name. Args: - world_name (str): Name of the world. - world_size (int): Number of processes in the world. - addr (str): Address of the leader process. - port (int): Port number. - backend (str): Backend to be used. - fn1 (function): Function to be executed in the world. - fn2 (function): Function to be executed in the world leader. + world_name: Name of the world. + world_size: Number of processes in the world. + addr: Address of the leader process. + port: Port number. + backend: Backend to be used. + fn1: Function to be executed in the world. + fn2: Function to be executed in the world leader. Returns: list: List of processes. @@ -226,7 +239,7 @@ async def create_world(world_name, world_size, addr, port, backend, fn1, fn2): global processes for rank in range(world_size): - if rank == 0: + if rank == LEADER_RANK: continue p = mp.Process( target=run_init_world, @@ -237,7 +250,7 @@ async def create_world(world_name, world_size, addr, port, backend, fn1, fn2): processes.append(p) # run leader late - await init_world(world_name, 0, world_size, fn2, backend, addr, port) + await init_world(world_name, LEADER_RANK, world_size, fn2, backend, addr, port) return processes @@ -326,8 +339,8 @@ async def single_host(args): for world_idx in range(1, args.num_workers + 1): pset = await create_world( f"world{world_idx}", - 2, - "127.0.0.1", + WORLD_SIZE, + args.addr, STARTING_PORT + world_idx, args.backend, run, @@ -349,12 +362,12 @@ async def multi_host(args): args: Command line arguments. """ size = int(args.num_workers) - if args.rank == 0: + if args.rank == LEADER_RANK: for world_idx in range(1, size + 1): await init_world( f"world{world_idx}", - 0, - 2, + LEADER_RANK, + WORLD_SIZE, dummy, args.backend, args.addr, @@ -366,7 +379,7 @@ async def multi_host(args): await init_world( f"world{args.rank}", 1, - 2, + WORLD_SIZE, run, args.backend, args.addr,