refactor: refactor resnet example #67

Merged 1 commit on Aug 8, 2024
109 changes: 61 additions & 48 deletions examples/resnet/m8d.py
@@ -32,7 +32,6 @@
 import argparse
 import asyncio
 import atexit
-import os
 import time
 
 import torch
@@ -64,14 +63,16 @@
 LEADER_RANK = 0
 # Worker rank in every world is going to be 1 because we are creating 2 processes in every world
 WORKER_RANK = 1
+STARTING_PORT = 29500
+WORLD_SIZE = 2
 
 
 def index_to_class_name(index):
     """
     Get class name from index.
 
     Args:
-        index (int): Index of the class.
+        index: Index of the class.
     """
     return CIFAR10_CLASS_NAMES[index]
 
@@ -81,7 +82,10 @@ def load_cifar10(batch_size=1):
     Load CIFAR10 dataset.
 
     Args:
-        batch_size (int): Batch size for the DataLoader.
+        batch_size: Batch size for the DataLoader.
+
+    Returns:
+        DataLoader instance
     """
     transform = transforms.Compose(
         [
@@ -96,35 +100,37 @@
     return cifar10_loader
 
 
-def dummy(world_name, rank, size, backend):
+async def dummy(world_name, rank, size, backend, world_communicator):
     """
     Dummy function to be implemented later.
 
     Args:
-        world_name (str): Name of the world.
-        rank (int): Rank of the process.
-        size (int): Number of processes.
-        backend (str): Backend used for communication.
+        world_name: Name of the world.
+        rank: Rank of the process.
+        size: Number of processes.
+        backend: Backend used for communication.
     """
 
     print(f"dummy function: world: {world_name}, my rank: {rank}, world size: {size}")
+    await asyncio.sleep(0)
 
 
-def run(world_name, rank, size, backend):
+async def run(world_name, rank, size, backend, world_communicator):
     """
     Distributed function to be implemented later.
 
     Args:
-        world_name (str): Name of the world.
-        rank (int): Rank of the process.
-        size (int): Number of processes.
-        backend (str): Backend used for communication.
+        world_name: Name of the world.
+        rank: Rank of the process.
+        size: Number of processes.
+        backend: Backend used for communication.
+        world_communicator: World communicator
     """
     world_idx = int(world_name[5:])
 
     # Initialize ResNet18 model
     model = AutoModelForImageClassification.from_pretrained(
-        "jialicheng/resnet-18-cifar10-21"
+        "edadaltocg/resnet18_cifar10"
     )
     model.eval()
 
@@ -136,8 +142,12 @@ def run(world_name, rank, size, backend):
         image_tensor = (
             image_tensor.to(f"cuda:{world_idx}") if backend == "nccl" else image_tensor
         )
-
-        dist.recv(image_tensor, src=LEADER_RANK)
+        try:
+            await world_communicator.recv(image_tensor, LEADER_RANK, world_name)
+        except Exception as e:
+            print("RECV FAILED for RANK", rank)
+            print(f"Caught an except while receiving predicted class: {e}")
+            break
 
         # Inference
         with torch.no_grad():
@@ -149,13 +159,16 @@
         print(f"Predicted : {predicted}, {predicted.shape}, {type(predicted)}")
 
         # Send the predicted class back to the leader
-        dist.send(predicted, dst=LEADER_RANK)
+        try:
+            await world_communicator.send(predicted, LEADER_RANK, world_name)
+        except Exception as e:
+            print(f"Caught an except while sending image: {e}")
+            break
 
         print(f"Predicted class: {predicted}")
 
 
 world_manager = None
-STARTING_PORT = 29500
 
 
 async def init_world(
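Note: the two hunks above replace the blocking dist.recv/dist.send calls with awaitable per-world communicator calls, wrapped in try/except so that a failure in one world only breaks this worker's loop rather than tearing the whole process down. The leader side of this exchange is outside the diff; a hypothetical sketch of its counterpart, using only the send/recv signature that appears above (tensor, peer rank, world name) and the module's existing imports and constants:

async def leader_step(world_communicator, world_name, image_tensor):
    # Ship one CIFAR-10 image to this world's worker ...
    await world_communicator.send(image_tensor, WORKER_RANK, world_name)
    # ... then collect the predicted class index the worker sends back
    # (shape/dtype mirror logits.argmax(dim=1) on a batch of one).
    predicted = torch.zeros(1, dtype=torch.int64)
    await world_communicator.recv(predicted, WORKER_RANK, world_name)
    return predicted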
@@ -165,25 +178,25 @@ async def init_world(
     Initialize the distributed environment.
 
     Args:
-        world_name (str): Name of the world.
-        rank (int): Rank of the process.
-        size (int): Number of processes.
-        fn (function): Function to be executed.
-        backend (str): Backend to be used.
-        addr (str): Address of the leader process.
-        port (int): Port to be used.
+        world_name: Name of the world.
+        rank: Rank of the process.
+        size: Number of processes.
+        fn: Function to be executed.
+        backend: Backend to be used.
+        addr: Address of the leader process.
+        port: Port to be used.
     """
     global world_manager
 
     if world_manager is None:
         # TODO: make WorldManager as singleton
         world_manager = dist.WorldManager()
 
-    world_manager.initialize_world(
+    await world_manager.initialize_world(
         world_name, rank, size, backend=backend, addr=addr, port=port
     )
 
-    fn(world_name, rank, size, backend)
+    await fn(world_name, rank, size, backend, world_manager.communicator)
 
 
 def run_init_world(
@@ -193,13 +206,13 @@
     Run the init_world function in a separate process.
 
     Args:
-        world_name (str): Name of the world.
-        rank (int): Rank of the process.
-        size (int): Number of processes.
+        world_name: Name of the world.
+        rank: Rank of the process.
+        size: Number of processes.
         fn (function): Function to be executed.
-        backend (str): Backend to be used.
-        addr (str): Address of the leader process.
-        port (int): Port to be used.
+        backend: Backend to be used.
+        addr: Address of the leader process.
+        port: Port to be used.
     """
     asyncio.run(init_world(world_name, rank, size, fn, backend, addr, port))
 
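Note: run_init_world is the bridge between multiprocessing and asyncio here: mp.Process needs a plain callable, and each child process needs its own event loop, so the coroutine is wrapped in asyncio.run. A sketch of how create_world below hands it to a worker process (values illustrative):

# Spawn the worker of "world1" in its own process; it builds a fresh event
# loop, joins the world as rank 1, and then executes run().
p = mp.Process(
    target=run_init_world,
    args=("world1", WORKER_RANK, WORLD_SIZE, run, "gloo", "127.0.0.1", STARTING_PORT + 1),
)
p.start()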
@@ -212,21 +225,21 @@ async def create_world(world_name, world_size, addr, port, backend, fn1, fn2):
     Create a world with the given port and world name.
 
     Args:
-        world_name (str): Name of the world.
-        world_size (int): Number of processes in the world.
-        addr (str): Address of the leader process.
-        port (int): Port number.
-        backend (str): Backend to be used.
-        fn1 (function): Function to be executed in the world.
-        fn2 (function): Function to be executed in the world leader.
+        world_name: Name of the world.
+        world_size: Number of processes in the world.
+        addr: Address of the leader process.
+        port: Port number.
+        backend: Backend to be used.
+        fn1: Function to be executed in the world.
+        fn2: Function to be executed in the world leader.
 
     Returns:
         list: List of processes.
     """
     global processes
 
     for rank in range(world_size):
-        if rank == 0:
+        if rank == LEADER_RANK:
             continue
         p = mp.Process(
             target=run_init_world,
@@ -237,7 +250,7 @@ async def create_world(world_name, world_size, addr, port, backend, fn1, fn2):
         processes.append(p)
 
     # run leader late
-    await init_world(world_name, 0, world_size, fn2, backend, addr, port)
+    await init_world(world_name, LEADER_RANK, world_size, fn2, backend, addr, port)
 
     return processes

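Note: spelled out, the pattern in create_world is to fork every non-leader rank first and only then have the current process join the same world as LEADER_RANK ("run leader late"), so the leader's init_world rendezvouses with workers that are already waiting. A hypothetical call for a single world, under the constants this diff introduces:

processes = await create_world(
    "world1",           # run() parses the world index from "world<N>"
    WORLD_SIZE,         # 2 processes: leader (rank 0) + worker (rank 1)
    "127.0.0.1",        # illustrative address; the example now takes args.addr
    STARTING_PORT + 1,  # each world rendezvouses on its own port
    "gloo",             # with "nccl", tensors are placed on cuda:<world_idx>
    run,                # fn1: executed by the worker
    dummy,              # fn2: executed by the leader
)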
@@ -326,8 +339,8 @@ async def single_host(args):
     for world_idx in range(1, args.num_workers + 1):
         pset = await create_world(
             f"world{world_idx}",
-            2,
-            "127.0.0.1",
+            WORLD_SIZE,
+            args.addr,
             STARTING_PORT + world_idx,
             args.backend,
             run,
@@ -349,12 +362,12 @@ async def multi_host(args):
         args: Command line arguments.
     """
     size = int(args.num_workers)
-    if args.rank == 0:
+    if args.rank == LEADER_RANK:
         for world_idx in range(1, size + 1):
             await init_world(
                 f"world{world_idx}",
-                0,
-                2,
+                LEADER_RANK,
+                WORLD_SIZE,
                 dummy,
                 args.backend,
                 args.addr,
@@ -366,7 +379,7 @@ async def multi_host(args):
         await init_world(
             f"world{args.rank}",
             1,
-            2,
+            WORLD_SIZE,
             run,
             args.backend,
             args.addr,
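Note: taken together, single_host and multi_host build the same topology: one leader process that is rank 0 in every world, plus one worker per world at rank 1. A sketch of the layout for num_workers = 2 (host roles illustrative, not part of the diff):

# leader process (args.rank == LEADER_RANK):
#     world1: rank 0, runs dummy(), port STARTING_PORT + 1
#     world2: rank 0, runs dummy(), port STARTING_PORT + 2
# worker host 1 (args.rank == 1):
#     world1: rank 1, runs run()
# worker host 2 (args.rank == 2):
#     world2: rank 1, runs run()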