[RFC] veScale: High-Level API for nD Parallel Training #39

leonardo0lyj opened this issue Jun 12, 2024 · 0 comments

TL;DR

(TL;DR figure)

Motivation

Our current APIs for nD Parallel Training are low-level and rather complex for common users. Ideally, we want a simpler, high-level API like the following:

Single Device Code

import torch
import torch.nn as nn
import torch.utils.data

dataset = ...
data_loader = torch.utils.data.DataLoader(dataset, ...)

class Net(nn.Module):
    ...

def optimizer_fn(model):
    ...
    return torch.optim.Adam(model_param_groups, ...)

def lr_scheduler_fn(optimizer):
    ...
    return torch.optim.lr_scheduler.StepLR(optimizer, ...)

model = Net(...)
optimizer = optimizer_fn(model)
scheduler = lr_scheduler_fn(optimizer)

for epoch in range(10):
    for batch in data_loader:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
    scheduler.step()

torch.save(model.state_dict(), "/path/to/checkpoint")
torch.save(optimizer.state_dict(), "/path/to/checkpoint")
torch.save(scheduler.state_dict(), "/path/to/checkpoint")

veScale High-Level API for nD Parallel Training

import torch
import torch.nn as nn
import vescale

dataset = ...

### zero code change on model
class Net(nn.Module):
    ...

def optimizer_fn(model):
    ...
    return torch.optim.Adam(model_param_groups, ...)

def lr_scheduler_fn(optimizer):
    ...
    return torch.optim.lr_scheduler.StepLR(optimizer, ...)

### create giant model without OOM
model = vescale.deferred_init(Net, ...)

### generate plan of nD parallel training under user constraints
# $ constraints = { "pipeline_parallel.split_method" : "flops",
# $                 "tensor_parallel.sharding_policy" : "megatron"  }
plan = vescale.generate_plan(constraints, model)
# $ print(plan)
# $   pipeline_parallel.split_points : ["layer1", "layer3", ...]
# $   tensor_parallel.sharding_plan : { "layer2.weight" : [Shard(dim=0)], ... }

### create nD parallel model and optimizer, specified by the plan
model, optimizer, scheduler, data_loader = vescale.parallelize(plan, model, optimizer_fn, lr_scheduler_fn, dataset)

### zero code change on training loop
for epoch in range(10):
    for batch in data_loader:
        optimizer.zero_grad()
        ### trains nD parallel model as if on single device
        loss = model(batch)
        loss.backward()
        optimizer.step()
    scheduler.step()
 
### save the nD parallel model, optimizer, and lr scheduler
vescale.save("/path/to/checkpoint", { "plan": plan, "model" : model, "optimizer" : optimizer, "lr_scheduler": scheduler })

Idea

  • Single Device Abstraction for nD Parallel Training
  • Common users only need to see this high-level API
  • Common users only need to write <10 LoC in their training scripts
  • veScale handles all the complexity under the hood (e.g., all low-level APIs)
  • This is a unified API for nD parallelism (see the sketch after this list)
    • Takes an nD parallel "Plan" (a.k.a. config)
    • Creates the nD DeviceMesh
    • Creates each D of the nD parallelism
    • Supports both Eager and Compile Mode
  • This is a future-proof API
    • can extend to future DeviceMesh
    • can extend to future parallelisms (e.g., EP, CP, *P)
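
For intuition only, here is a minimal sketch of how the unified API could unpack a plan into an nD DeviceMesh and the per-dimension parallelisms. The only real call is torch.distributed.device_mesh.init_device_mesh; the mesh shape, dim names, and the numbered steps are illustrative assumptions, not veScale's actual implementation.

from torch.distributed.device_mesh import init_device_mesh

def parallelize_sketch(plan, model, optimizer_fn, lr_scheduler_fn, dataset):
    # 1. create the nD DeviceMesh; the 3D shape and dim names here are arbitrary examples
    mesh = init_device_mesh("cuda", mesh_shape=(2, 2, 4), mesh_dim_names=("pp", "dp", "tp"))

    # 2. create each D of the nD parallelism from the plan, e.g.:
    #    - split the model at plan's pipeline_parallel.split_points over mesh["pp"]
    #    - shard parameters per plan's tensor_parallel.sharding_plan over mesh["tp"]
    #    - replicate the model and reduce gradients over mesh["dp"]
    # 3. build the distributed optimizer and scheduler via optimizer_fn / lr_scheduler_fn
    # 4. shard the dataset into a per-rank data_loader over mesh["dp"]
    ...
    return model, optimizer, scheduler, data_loader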

Feedback is all we need :)

(image source: 1 2)
