Directly send tensor via jit serialization #3088

Merged
merged 21 commits into NVIDIA:main on Dec 13, 2024
Conversation

ZiyueXu77
Collaborator

Fixes # .

Description

Directly send tensor without converting to numpy
Using jit serialization to avoid pickle
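
As a rough illustration of the approach described above, here is a minimal sketch of TorchScript-based tensor serialization that avoids pickle. The helper names (_TensorWrapper, tensor_to_bytes, bytes_to_tensor) are hypothetical and not necessarily what this PR implements.

```python
import io

import torch


class _TensorWrapper(torch.nn.Module):
    """Holds a single tensor so it can be saved with torch.jit (no pickle)."""

    def __init__(self, tensor: torch.Tensor):
        super().__init__()
        # A registered buffer is guaranteed to be stored in the TorchScript archive.
        self.register_buffer("value", tensor)

    def forward(self) -> torch.Tensor:
        return self.value


def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
    """Serialize a tensor of any dtype (including bfloat16) without pickle."""
    buffer = io.BytesIO()
    torch.jit.save(torch.jit.script(_TensorWrapper(tensor)), buffer)
    return buffer.getvalue()


def bytes_to_tensor(data: bytes) -> torch.Tensor:
    """Recover the tensor on the receiving side."""
    return torch.jit.load(io.BytesIO(data))()
```

Because TorchScript stores the raw tensor data, dtypes with no numpy equivalent (such as torch.bfloat16) survive the round trip unchanged.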

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Quick tests passed locally by running ./runtest.sh.
  • In-line docstrings updated.
  • Documentation updated.

@ZiyueXu77 ZiyueXu77 marked this pull request as draft December 2, 2024 21:57
@ZiyueXu77 ZiyueXu77 marked this pull request as ready for review December 3, 2024 22:16
Collaborator

@chesterxgchen chesterxgchen left a comment


The logic seems to be only for LLM BF16. I think what we want to achieve is support for all tensors, whether they are LLM BF16 or not.

@ZiyueXu77
Collaborator Author

The logic seems to be only for LLM BF16. I think what we want to achieve is support for all tensors, whether they are LLM BF16 or not.

I think we are mixing up two processes: conversion for filtering, and conversion for communication.

Consider local-to-server communication (the reverse direction is similar) with quantization:
local model --> to_nvflare_converter --> quant_filter --> (decomposer) --> communication --> (composer) --> dequant_filter --> global

Currently our client API executor uses PTtoNumpy as the default to_nvflare_converter, so everything afterwards is in numpy, including the serialization step, and the tensor decomposer/composer is never called.

Now, if we use a simple "pass-through" to_nvflare_converter instead of PTtoNumpy, there are two implications:

  1. Filters need to handle tensors properly.
  2. A decomposer and composer will be needed to handle tensor communication, and currently that again goes through numpy.

Hence there are two places where tensor<->numpy conversion can happen: the converter for filters, and the decomposer for communication. The first means all subsequent computations (filters) run in numpy, while the second means only the communication/serialization goes through numpy - the data is recovered as tensors once received, so the whole pipeline is "virtually" still in tensors.

For the sake of serialization efficiency, my guess is that numpy may be more efficient than jit (@nvidianz to confirm); in that case jit is only needed for formats numpy does not support (e.g. bf16). Otherwise, we can use jit for all cases (and maybe "safe tensor" as suggested).
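
As a rough sketch of the two to_nvflare_converter options discussed above (function names are illustrative, not NVFLARE's actual classes):

```python
from typing import Dict

import numpy as np
import torch


def pt_to_numpy_converter(params: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
    """Filter path: everything downstream (filters, serialization) sees numpy."""
    out = {}
    for name, tensor in params.items():
        t = tensor.detach().cpu()
        if t.dtype == torch.bfloat16:
            # numpy has no bf16 dtype, so an upcast (and a 2x size increase) happens here
            t = t.to(torch.float32)
        out[name] = t.numpy()
    return out


def pass_through_converter(params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Tensor path: filters and the wire format both see native tensors, so filters
    must handle torch.Tensor and a tensor decomposer/composer is needed."""
    return params
```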

@chesterxgchen
Collaborator

(quoting @ZiyueXu77's reply above)

The reason for avoiding the to_numpy() conversion is to avoid losing the tensor compression ratio, i.e., to make sure the model does not grow after transfer. It doesn't matter whether this conversion happens in a filter or elsewhere: if we serialize the tensor with JIT in one place but call to_numpy() in another place before sending it over the wire, we have already lost the compression and the JIT serialization becomes pointless.

We need tensors to stay native throughout the communication pipeline.

@ZiyueXu77
Collaborator Author

ZiyueXu77 commented Dec 4, 2024

(quoting the exchange above)

No, this is not the case. Using numpy + jit for serialization will not lead to a bigger message; only the conversion for filtering purposes will, because there we want everything in numpy and therefore have to cast bf16 to float32 so that it can be converted.
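
A quick back-of-the-envelope check of this point (the wrapper class is illustrative and sizes are approximate): serializing bf16 natively keeps 2 bytes per element, while the float32 cast needed for numpy doubles it.

```python
import io

import torch


class _Wrap(torch.nn.Module):
    def __init__(self, t: torch.Tensor):
        super().__init__()
        self.register_buffer("t", t)

    def forward(self) -> torch.Tensor:
        return self.t


# A 1M-element bfloat16 tensor occupies 2 bytes per element in memory.
bf16_tensor = torch.randn(1_000_000).to(torch.bfloat16)

# Filter path: casting to float32 so numpy can hold it doubles the payload (~4 MB).
print("float32/numpy bytes:", bf16_tensor.to(torch.float32).numpy().nbytes)

# Serialization path: TorchScript keeps bf16 native, so the payload stays at
# ~2 MB plus a small archive overhead - the message does not get bigger.
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(_Wrap(bf16_tensor)), buffer)
print("jit-serialized bytes:", len(buffer.getvalue()))
```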

nvidianz
nvidianz previously approved these changes Dec 11, 2024
@ZiyueXu77
Collaborator Author

/build

@ZiyueXu77 ZiyueXu77 enabled auto-merge (squash) December 11, 2024 21:42
nvidianz
nvidianz previously approved these changes Dec 13, 2024
@ZiyueXu77
Collaborator Author

/build

Collaborator

@chesterxgchen chesterxgchen left a comment


Need to change the package path.

@ZiyueXu77
Collaborator Author

/build

Collaborator

@chesterxgchen chesterxgchen left a comment


LGTM

@ZiyueXu77 ZiyueXu77 merged commit 38157c3 into NVIDIA:main Dec 13, 2024
20 checks passed