Remove last references to use_distributed argument (#1353)
Co-authored-by: Jack-Khuu <[email protected]>
mreso and Jack-Khuu authored Nov 13, 2024
1 parent 2fcc37c commit 0f58543
Showing 2 changed files with 3 additions and 87 deletions.
74 changes: 1 addition & 73 deletions torchchat/cli/builder.py
@@ -16,12 +16,6 @@
import torch._inductor.config
import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.utils.distributed import get_free_port

from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama

from torchchat.model import Model, ModelArgs, ModelType

from torchchat.model_config.model_config import resolve_model_config
@@ -464,77 +458,11 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
return model


def _maybe_init_distributed(
builder_args: BuilderArgs,
) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
"""
Initialize distributed related setups if the user specified
using distributed inference. If not, this is a no-op.
Args:
builder_args (:class:`BuilderArgs`):
Command args for model building.
Returns:
Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
- The first element is an optional DeviceMesh object,
which describes the mesh topology of devices for the DTensor.
- The second element is an optional ParallelDims object,
which represents the parallel dimensions configuration.
"""
if not builder_args.use_distributed:
return None, None
dist_config = "llama3_8B.toml" # TODO - integrate with chat cmd line

world_mesh, parallel_dims = launch_distributed(dist_config)

assert (
world_mesh is not None and parallel_dims is not None
), f"failed to launch distributed using {dist_config}"

return world_mesh, parallel_dims


def _maybe_parallelize_model(
model: nn.Module,
builder_args: BuilderArgs,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
) -> nn.Module:
"""
We parallelize the module and load the distributed checkpoint to the model
if the user specifies using distributed inference. If not, this is a no-op.
Args:
model (:class:`nn.Module`):
Module to be parallelized.
builder_args (:class:`BuilderArgs`):
Command args for model building.
world_mesh (:class:`DeviceMesh`):
Object which describes the mesh topology
of devices for the DTensor.
parallel_dims (:class:`ParallelDims`):
Object which represents the parallel dimensions configuration.
Returns:
A :class:`nn.Module` object which is parallelized and checkpoint loaded
if the user specifies using distributed inference.
"""
if world_mesh is None:
return model
assert parallel_dims is not None
print("Applying model parallel to model ...")
parallelize_llama(model, world_mesh, parallel_dims)
return load_checkpoints_to_model(model, builder_args, world_mesh)


def _load_model(builder_args: BuilderArgs) -> Model:
# world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
if builder_args.gguf_path:
model = _load_model_gguf(builder_args)
# elif builder_args.use_distributed:
# model = _init_model_on_meta_device(builder_args)
else:
model = _load_model_default(builder_args)
# model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)

if builder_args.dso_path or builder_args.aoti_package_path:
# AOTI-compiled model will load its own weights.
@@ -706,4 +634,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
return "TikToken"
if tokenizers:
return "Tokenizers"
return "SentencePiece"
return "SentencePiece"
16 changes: 2 additions & 14 deletions torchchat/generate.py
@@ -915,13 +915,6 @@ def chat(
]
)
if generator_args.compile:
if (
self.is_speculative and self.builder_args.use_distributed
): # and ("cuda" in builder_args.device):
torch._inductor.config.triton.cudagraph_trees = (
False # Bug with cudagraph trees in this case
)

if self.builder_args.device == "cpu":
if generator_args.max_autotune:
kwargs = {"mode": "max-autotune"}
@@ -1091,9 +1084,7 @@ def callback(x, *, done_generating=False):

torch._inductor.config.profiler_mark_wrapper_call = True
torch._inductor.config.cpp.enable_kernel_profile = True
if (i != generator_args.num_samples - 1 or not self.profile) or (
self.builder_args.use_distributed and self.rank != 0
):
if i != generator_args.num_samples - 1 or not self.profile:
import contextlib

prof = contextlib.nullcontext()
@@ -1136,10 +1127,7 @@ def callback(x, *, done_generating=False):
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
else:
print(prof.key_averages().table(sort_by="self_cuda_time_total"))
if self.builder_args.use_distributed:
prof.export_chrome_trace(f"{self.profile}_rank_{self.rank}.json")
else:
prof.export_chrome_trace(f"{self.profile}.json")
prof.export_chrome_trace(f"{self.profile}.json")

if start_pos >= max_seq_length:
print(
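
For reference (also not part of this commit), the surviving profiling code in generate.py keeps a common "conditional profiler" pattern: wrap the step in torch.profiler.profile only when a profile path was requested, fall back to contextlib.nullcontext() otherwise, and, after this change, export a single chrome trace with no per-rank suffix. Below is a hedged, self-contained sketch of that pattern; run_step and generate_with_optional_profile are hypothetical names, while the torch.profiler calls (profile, ProfilerActivity, key_averages, export_chrome_trace) are the standard API.

import contextlib
from typing import Optional

import torch
from torch.profiler import ProfilerActivity, profile


def run_step():
    # Stand-in for one generation step; a small matmul gives the profiler work to record.
    a = torch.randn(256, 256)
    return a @ a


def generate_with_optional_profile(profile_path: Optional[str] = None):
    # Profile only when a trace path was requested; otherwise use a null context,
    # mirroring the contextlib.nullcontext() fallback kept in generate.py.
    if profile_path is None:
        prof_ctx = contextlib.nullcontext()
    else:
        prof_ctx = profile(activities=[ProfilerActivity.CPU])

    with prof_ctx as prof:
        out = run_step()

    if profile_path is not None:
        # One summary table and one trace file; after this commit there is no
        # per-rank "_rank_{rank}" suffix on the trace name.
        print(prof.key_averages().table(sort_by="self_cpu_time_total"))
        prof.export_chrome_trace(f"{profile_path}.json")
    return out


if __name__ == "__main__":
    generate_with_optional_profile()            # no profiling
    generate_with_optional_profile("prefill")   # writes prefill.json
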
