[core][distributed] simplify code to support pipeline parallel #6406

Merged: 11 commits, Jul 15, 2024

Conversation

youkaichao (Member)

The goal is to minimize the lines of code a model needs to change in order to support pipeline parallel.

In the model weight loading part, just add:

                if name not in params_dict:
                    # in pipeline parallelism, we may have layers that are not
                    # present on this rank
                    continue

(the benefit is that we avoid the extra indentation a try-except block would introduce)
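
For context, a minimal sketch of how that check might sit inside a typical load_weights loop (treat this as illustrative, assuming the usual params_dict/weight_loader pattern; the exact loop structure and helpers vary per model):

    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)

    # Illustrative load_weights with the PP check; real models also handle
    # stacked/fused parameters before this fallback path.
    def load_weights(self, weights):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if name not in params_dict:
                # in pipeline parallelism, we may have layers that are not
                # present on this rank
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)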

In the layer construction part, just add:

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda: LlamaDecoderLayer(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config))

Hopefully, with these changes, adding pipeline parallel support to a model will be easier, and reviews will be easier too.
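
For intuition, a simplified sketch of what such a make_layers helper could look like. This assumes an even split across pipeline ranks and passes pp_rank/pp_size explicitly; the real vLLM helper reads the rank from the distributed state and uses a dedicated placeholder layer class, so this is only a sketch:

    import torch.nn as nn

    def make_layers(num_hidden_layers: int, layer_fn, pp_rank: int,
                    pp_size: int):
        """Build only this pipeline rank's slice of layers; placeholders
        elsewhere keep global layer indices consistent across ranks."""
        per_rank = num_hidden_layers // pp_size  # assumes an even split
        start_layer = pp_rank * per_rank
        end_layer = (num_hidden_layers if pp_rank == pp_size - 1
                     else start_layer + per_rank)
        layers = nn.ModuleList(
            [nn.Identity() for _ in range(start_layer)] +
            [layer_fn() for _ in range(start_layer, end_layer)] +
            [nn.Identity() for _ in range(end_layer, num_hidden_layers)])
        return start_layer, end_layer, layers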


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only trigger the fastcheck CI, which consists of a small, essential subset of tests to quickly catch errors, with the flexibility to run extra individual tests on top (you can do this by unblocking test steps in the Buildkite run).

A full CI run is still required to merge this PR, so once the PR is ready to go, please make sure to run it. If you need all test signals in between PR commits, you can trigger full CI as well.

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@DarkLight1337 (Member) commented Jul 13, 2024

Perhaps it would be better to move the make_layers function into the utils file inside models/?

Inline review thread on this change in load_weights:

    -        except KeyError:
    -            pass
    +        if name not in params_dict:
Collaborator:

This could silence some hard-to-track-down bugs when loading more complex state dicts (e.g. in the quantized case).

Could we check whether the name corresponds to a layer that is absent from this rank because of PP?

youkaichao (Member, Author):

This is quite difficult, as load_weights does not have layer information.

youkaichao (Member, Author):

Came up with a workaround in 61fa242. PTAL!

@youkaichao (Member, Author)

> Perhaps it would be better to move the make_layers function into the utils file inside models/?

Fixed in 347399e.

@andoorve (Collaborator) commented Jul 14, 2024

I noticed a CUDA out-of-memory error on the basic correctness tests here. Is it reproducible locally? I don't think this change should cause that, so it's possibly a flaky test?

@youkaichao (Member, Author)

/ready

github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Jul 14, 2024
@youkaichao (Member, Author)

@andoorve finally figured it out: lru_cache stores a reference to the model, which keeps it alive and defeats the GC :(
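
A minimal illustration of that failure mode (hypothetical function, not vLLM code): an lru_cache keyed on the model holds a strong reference to it, so the model is never garbage collected until the cache is cleared.

    import functools
    import gc
    import weakref

    import torch.nn as nn

    @functools.lru_cache
    def cached_info(model: nn.Module):
        # the cache keeps `model` alive as part of its key
        return id(model)

    model = nn.Linear(2, 2)
    ref = weakref.ref(model)
    cached_info(model)
    del model
    gc.collect()
    assert ref() is not None   # still alive: the cache holds a reference
    cached_info.cache_clear()
    gc.collect()
    assert ref() is None       # cache cleared, model finally collected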

@youkaichao (Member, Author)

Merging first to unblock the follow-up model support work.

@youkaichao merged commit 69672f1 into vllm-project:main on Jul 15, 2024 (73 checks passed)
@youkaichao deleted the simplify_pp branch on July 15, 2024 at 04:20
Review thread on this helper:

    def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
        """Check if a parameter is missing in a pipeline parallel model."""
        for missing_layer_name in get_pp_missing_layer_names(model):
            if name.startswith(missing_layer_name):
                return True
        return False
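
In a model's load_weights, this helper would then replace the bare membership check, along these lines (illustrative):

    for name, loaded_weight in weights:
        if is_pp_missing_parameter(name, self):
            # weight belongs to a layer owned by another pipeline rank
            continue
        ...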
Contributor:

:-) "xx.11".startswith("xx.1") is True, so a parameter of layer 11 can be wrongly treated as belonging to missing layer 1.

youkaichao (Member, Author):

Sorry for the bug, and thanks for pointing it out so quickly! Please take a look at #6446.
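
For illustration, one boundary-aware variant that avoids the false prefix match (a sketch only; the actual fix landed in #6446 and may differ in detail):

    def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
        """Check if a parameter is missing in a pipeline parallel model."""
        for missing_layer_name in get_pp_missing_layer_names(model):
            # exact match or a '.' boundary, so that e.g. "layers.11.weight"
            # does not match missing layer "layers.1"
            if name == missing_layer_name or name.startswith(
                    missing_layer_name + "."):
                return True
        return False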

Labels: ready (ONLY add when PR is ready to merge/full CI is needed)
5 participants