Remove last references to use_distributed argument #1353

Merged: 3 commits, Nov 13, 2024
74 changes: 1 addition & 73 deletions torchchat/cli/builder.py
@@ -16,12 +16,6 @@
import torch._inductor.config
import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.utils.distributed import get_free_port

from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama

from torchchat.model import Model, ModelArgs, ModelType

from torchchat.model_config.model_config import resolve_model_config
@@ -464,77 +458,11 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
return model


def _maybe_init_distributed(
builder_args: BuilderArgs,
) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
"""
Initialize distributed related setups if the user specified
using distributed inference. If not, this is a no-op.
Args:
builder_args (:class:`BuilderArgs`):
Command args for model building.
Returns:
Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
- The first element is an optional DeviceMesh object,
which describes the mesh topology of devices for the DTensor.
- The second element is an optional ParallelDims object,
which represents the parallel dimensions configuration.
"""
if not builder_args.use_distributed:
return None, None
dist_config = "llama3_8B.toml" # TODO - integrate with chat cmd line

world_mesh, parallel_dims = launch_distributed(dist_config)

assert (
world_mesh is not None and parallel_dims is not None
), f"failed to launch distributed using {dist_config}"

return world_mesh, parallel_dims


def _maybe_parallelize_model(
model: nn.Module,
builder_args: BuilderArgs,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
) -> nn.Module:
"""
We parallelize the module and load the distributed checkpoint to the model
if the user specifies using distributed inference. If not, this is a no-op.
Args:
model (:class:`nn.Module`):
Module to be parallelized.
builder_args (:class:`BuilderArgs`):
Command args for model building.
world_mesh (:class:`DeviceMesh`):
Object which describes the mesh topology
of devices for the DTensor.
parallel_dims (:class:`ParallelDims`):
Object which represents the parallel dimensions configuration.
Returns:
A :class:`nn.Module` object which is parallelized and checkpoint loaded
if the user specifies using distributed inference.
"""
if world_mesh is None:
return model
assert parallel_dims is not None
print("Applying model parallel to model ...")
parallelize_llama(model, world_mesh, parallel_dims)
return load_checkpoints_to_model(model, builder_args, world_mesh)


def _load_model(builder_args: BuilderArgs) -> Model:
# world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
if builder_args.gguf_path:
model = _load_model_gguf(builder_args)
# elif builder_args.use_distributed:
# model = _init_model_on_meta_device(builder_args)
else:
model = _load_model_default(builder_args)
# model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)

if builder_args.dso_path or builder_args.aoti_package_path:
# AOTI-compiled model will load its own weights.
@@ -706,4 +634,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
return "TikToken"
if tokenizers:
return "Tokenizers"
return "SentencePiece"
return "SentencePiece"
16 changes: 2 additions & 14 deletions torchchat/generate.py
@@ -915,13 +915,6 @@ def chat(
]
)
if generator_args.compile:
if (
self.is_speculative and self.builder_args.use_distributed
): # and ("cuda" in builder_args.device):
torch._inductor.config.triton.cudagraph_trees = (
False # Bug with cudagraph trees in this case
)

if self.builder_args.device == "cpu":
if generator_args.max_autotune:
kwargs = {"mode": "max-autotune"}
@@ -1091,9 +1084,7 @@ def callback(x, *, done_generating=False):

torch._inductor.config.profiler_mark_wrapper_call = True
torch._inductor.config.cpp.enable_kernel_profile = True
if (i != generator_args.num_samples - 1 or not self.profile) or (
self.builder_args.use_distributed and self.rank != 0
):
if i != generator_args.num_samples - 1 or not self.profile:
import contextlib

prof = contextlib.nullcontext()
@@ -1136,10 +1127,7 @@ def callback(x, *, done_generating=False):
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
else:
print(prof.key_averages().table(sort_by="self_cuda_time_total"))
if self.builder_args.use_distributed:
prof.export_chrome_trace(f"{self.profile}_rank_{self.rank}.json")
else:
prof.export_chrome_trace(f"{self.profile}.json")
prof.export_chrome_trace(f"{self.profile}.json")

if start_pos >= max_seq_length:
print(
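The profiling path that remains in generate.py after this change is the standard torch.profiler flow: summarize operator time with `key_averages().table(...)` and write a Chrome trace with `export_chrome_trace(...)`. A minimal, self-contained sketch of that pattern follows; the toy model and trace filename are placeholders, not torchchat code.

```python
# Minimal sketch of the profiling pattern kept by this diff: collect a
# trace, print an operator summary, and export a Chrome trace file.
import torch
from torch.profiler import ProfilerActivity, profile

model = torch.nn.Linear(512, 512)
inp = torch.randn(8, 512)

with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA on GPU
    record_shapes=True,
) as prof:
    model(inp)

# Same calls generate.py makes after profiling a sample.
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
prof.export_chrome_trace("profile_trace.json")
```

The exported JSON can be opened in chrome://tracing or Perfetto to inspect the timeline.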