refactor: refactor resnet example (#67)
Updated the resnet example to use the multiworld API.
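
The gist of the change: blocking torch.distributed point-to-point calls (dist.send / dist.recv) become awaitable calls on a per-world communicator obtained from a WorldManager, so one process can serve several worlds. Below is a minimal sketch of the worker-side pattern, assuming the multiworld package is installed and exposes the signatures used at the call sites in this diff; `worker` and `buf` are illustrative names, and the module import (aliased `dist` in the example file) is collapsed out of this view, so it is passed in as a parameter here.

import torch


async def worker(dist, world_name: str, rank: int) -> None:
    # `dist` is assumed to be the module providing WorldManager,
    # as the call sites in this diff suggest; the real import is not shown.
    world_manager = dist.WorldManager()
    await world_manager.initialize_world(
        world_name, rank, 2, backend="gloo", addr="127.0.0.1", port=29501
    )
    comm = world_manager.communicator

    buf = torch.zeros(1, 3, 32, 32)  # CIFAR-10-shaped receive buffer
    await comm.recv(buf, 0, world_name)  # was: dist.recv(buf, src=0)
    await comm.send(buf.flatten().argmax(), 0, world_name)  # was: dist.send(..., dst=0)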

Co-authored-by: Rares Gaia <[email protected]>
raresgaia123 and Rares Gaia authored Aug 8, 2024
1 parent 06598bb commit 53dbf84
Showing 1 changed file with 61 additions and 48 deletions.
examples/resnet/m8d.py (109 changes: 61 additions & 48 deletions)
@@ -32,7 +32,6 @@
 import argparse
 import asyncio
 import atexit
-import os
 import time
 
 import torch
@@ -64,14 +63,16 @@
 LEADER_RANK = 0
 # Worker rank in every world is going to be 1 because we are creating 2 processes in every world
 WORKER_RANK = 1
+STARTING_PORT = 29500
+WORLD_SIZE = 2
 
 
 def index_to_class_name(index):
     """
     Get class name from index.
 
     Args:
-        index (int): Index of the class.
+        index: Index of the class.
     """
     return CIFAR10_CLASS_NAMES[index]
 
@@ -81,7 +82,10 @@ def load_cifar10(batch_size=1):
     Load CIFAR10 dataset.
 
     Args:
-        batch_size (int): Batch size for the DataLoader.
+        batch_size: Batch size for the DataLoader.
+
+    Returns:
+        DataLoader instance
     """
     transform = transforms.Compose(
         [
@@ -96,35 +100,37 @@ def load_cifar10(batch_size=1):
     return cifar10_loader
 
 
-def dummy(world_name, rank, size, backend):
+async def dummy(world_name, rank, size, backend, world_communicator):
     """
     Dummy function to be implemented later.
 
     Args:
-        world_name (str): Name of the world.
-        rank (int): Rank of the process.
-        size (int): Number of processes.
-        backend (str): Backend used for communication.
+        world_name: Name of the world.
+        rank: Rank of the process.
+        size: Number of processes.
+        backend: Backend used for communication.
     """
 
     print(f"dummy function: world: {world_name}, my rank: {rank}, world size: {size}")
+    await asyncio.sleep(0)
 
 
-def run(world_name, rank, size, backend):
+async def run(world_name, rank, size, backend, world_communicator):
     """
     Distributed function to be implemented later.
 
     Args:
-        world_name (str): Name of the world.
-        rank (int): Rank of the process.
-        size (int): Number of processes.
-        backend (str): Backend used for communication.
+        world_name: Name of the world.
+        rank: Rank of the process.
+        size: Number of processes.
+        backend: Backend used for communication.
+        world_communicator: World communicator
     """
     world_idx = int(world_name[5:])
 
     # Initialize ResNet18 model
     model = AutoModelForImageClassification.from_pretrained(
-        "jialicheng/resnet-18-cifar10-21"
+        "edadaltocg/resnet18_cifar10"
     )
     model.eval()
 
@@ -136,8 +142,12 @@ def run(world_name, rank, size, backend):
         image_tensor = (
             image_tensor.to(f"cuda:{world_idx}") if backend == "nccl" else image_tensor
         )
-
-        dist.recv(image_tensor, src=LEADER_RANK)
+        try:
+            await world_communicator.recv(image_tensor, LEADER_RANK, world_name)
+        except Exception as e:
+            print("RECV FAILED for RANK", rank)
+            print(f"Caught an except while receiving predicted class: {e}")
+            break
 
         # Inference
         with torch.no_grad():
@@ -149,13 +159,16 @@
             print(f"Predicted : {predicted}, {predicted.shape}, {type(predicted)}")
 
         # Send the predicted class back to the leader
-        dist.send(predicted, dst=LEADER_RANK)
+        try:
+            await world_communicator.send(predicted, LEADER_RANK, world_name)
+        except Exception as e:
+            print(f"Caught an except while sending image: {e}")
+            break
 
         print(f"Predicted class: {predicted}")
 
 
 world_manager = None
-STARTING_PORT = 29500
 
 
 async def init_world(
@@ -165,25 +178,25 @@
     Initialize the distributed environment.
 
     Args:
-        world_name (str): Name of the world.
-        rank (int): Rank of the process.
-        size (int): Number of processes.
-        fn (function): Function to be executed.
-        backend (str): Backend to be used.
-        addr (str): Address of the leader process.
-        port (int): Port to be used.
+        world_name: Name of the world.
+        rank: Rank of the process.
+        size: Number of processes.
+        fn: Function to be executed.
+        backend: Backend to be used.
+        addr: Address of the leader process.
+        port: Port to be used.
     """
     global world_manager
 
     if world_manager is None:
         # TODO: make WorldManager as singleton
         world_manager = dist.WorldManager()
 
-    world_manager.initialize_world(
+    await world_manager.initialize_world(
         world_name, rank, size, backend=backend, addr=addr, port=port
     )
 
-    fn(world_name, rank, size, backend)
+    await fn(world_name, rank, size, backend, world_manager.communicator)
 
 
 def run_init_world(
@@ -193,13 +206,13 @@
     Run the init_world function in a separate process.
 
     Args:
-        world_name (str): Name of the world.
-        rank (int): Rank of the process.
-        size (int): Number of processes.
+        world_name: Name of the world.
+        rank: Rank of the process.
+        size: Number of processes.
         fn (function): Function to be executed.
-        backend (str): Backend to be used.
-        addr (str): Address of the leader process.
-        port (int): Port to be used.
+        backend: Backend to be used.
+        addr: Address of the leader process.
+        port: Port to be used.
     """
     asyncio.run(init_world(world_name, rank, size, fn, backend, addr, port))
 
@@ -212,21 +225,21 @@ async def create_world(world_name, world_size, addr, port, backend, fn1, fn2):
     Create a world with the given port and world name.
 
     Args:
-        world_name (str): Name of the world.
-        world_size (int): Number of processes in the world.
-        addr (str): Address of the leader process.
-        port (int): Port number.
-        backend (str): Backend to be used.
-        fn1 (function): Function to be executed in the world.
-        fn2 (function): Function to be executed in the world leader.
+        world_name: Name of the world.
+        world_size: Number of processes in the world.
+        addr: Address of the leader process.
+        port: Port number.
+        backend: Backend to be used.
+        fn1: Function to be executed in the world.
+        fn2: Function to be executed in the world leader.
 
     Returns:
         list: List of processes.
     """
     global processes
 
     for rank in range(world_size):
-        if rank == 0:
+        if rank == LEADER_RANK:
             continue
         p = mp.Process(
             target=run_init_world,
@@ -237,7 +250,7 @@ async def create_world(world_name, world_size, addr, port, backend, fn1, fn2):
         processes.append(p)
 
     # run leader late
-    await init_world(world_name, 0, world_size, fn2, backend, addr, port)
+    await init_world(world_name, LEADER_RANK, world_size, fn2, backend, addr, port)
 
     return processes
 
@@ -326,8 +339,8 @@ async def single_host(args):
     for world_idx in range(1, args.num_workers + 1):
         pset = await create_world(
             f"world{world_idx}",
-            2,
-            "127.0.0.1",
+            WORLD_SIZE,
+            args.addr,
             STARTING_PORT + world_idx,
             args.backend,
             run,
@@ -349,12 +362,12 @@ async def multi_host(args):
         args: Command line arguments.
     """
     size = int(args.num_workers)
-    if args.rank == 0:
+    if args.rank == LEADER_RANK:
         for world_idx in range(1, size + 1):
             await init_world(
                 f"world{world_idx}",
-                0,
-                2,
+                LEADER_RANK,
+                WORLD_SIZE,
                 dummy,
                 args.backend,
                 args.addr,
@@ -366,7 +379,7 @@
         await init_world(
             f"world{args.rank}",
             1,
-            2,
+            WORLD_SIZE,
             run,
             args.backend,
             args.addr,
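
Note on the collapsed region: the leader-side loop is not shown in this diff, but the worker protocol in run() pins down its shape: send an image tensor into each worker world, then receive the predicted class back. A hedged reconstruction follows; leader_loop, the prediction dtype, and the tensor shapes are assumptions, not code from this commit.

import torch

WORKER_RANK = 1  # mirrors the constant at the top of m8d.py


async def leader_loop(world_communicator, image_tensor, num_workers):
    # Counterpart of the worker's recv-then-send protocol in run():
    # send one image into each world, then collect the predicted class.
    for world_idx in range(1, num_workers + 1):
        world_name = f"world{world_idx}"
        await world_communicator.send(image_tensor, WORKER_RANK, world_name)
        predicted = torch.zeros(1, dtype=torch.int64)  # assumed dtype/shape
        await world_communicator.recv(predicted, WORKER_RANK, world_name)
        print(f"{world_name} predicted class {predicted.item()}")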
