Enhanced RNG State Management with Index-Based Control for Graph-Safe Tensor Parallelism #58859
Conversation
x = generate_random_number_and_states(gen)

assert_allclose = lambda x, y: np.testing.assert_allclose(
Why not array_equal?
Referenced from test_cuda_random_seed; I believe we should use assert_allclose in this context since it involves floats.
I checked the first case, test_gen_dropout_dygraph, in test_cuda_random_seed; it is also OK if allclose is changed to equal.
Changed to assert_array_equal.
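For illustration, a tiny sketch (numpy assumed available; the values are made up) of the kind of check the thread converged on: outputs produced after restoring the same RNG state are expected to be bit-identical, so an exact comparison suffices:

import numpy as np

out_before = np.array([0.5488135, 0.71518937])
out_after_restore = out_before.copy()  # stands in for re-running with the restored state
np.testing.assert_array_equal(out_before, out_after_restore)  # exact equality, no tolerance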
Hello, I require approval for this change from one of the following individuals. RD approval for API change: @XiaoguangHu01, @jeff41404, @lanxianghit, or @qingqing01
@eee4017 LGTM for this API naming and semantics. I noticed that you have suggested a similar feature in the PyTorch community: pytorch/pytorch#113541. Perhaps we should wait for feedback from the PyTorch community; if the proposed API is compatible, it could greatly benefit both communities. If this PR is blocking our ongoing project, please let me know. As a temporary solution, we can merge this PR now and adjust the API naming and semantics later, before our next release.
Thank you for reviewing the proposed API change. If this alteration is acceptable, I recommend merging this PR; that would simplify our process by minimizing the need for extensive communication and the time spent re-running CI tests. It's important to note that the implementation of the RNG in PyTorch differs from that in Paddle, so that issue may require a thorough review and discussion, and the timeframe for it is uncertain. I will keep you updated with any progress or new insights from the PyTorch community. In the meantime, moving forward with the current PR would be a practical step to maintain momentum in our project. Thank you.
LGTM
python/paddle/__init__.py (outdated)
@@ -811,6 +812,7 @@
    'unique_consecutive',
    'set_cuda_rng_state',
    'set_rng_state',
    'register_rng_state_as_index',
Should register_rng_state_as_index be a public API?
Move these APIs to incubate.
… Tensor Parallelism (PaddlePaddle#58859) * allow multiple rng state in generator * fix get_rng_state
…aph-Safe Tensor Parallelism (PaddlePaddle#58859)" This reverts commit 3bcdeef.
VLOG(4) << "initial seed: " << this->state_.current_seed | ||
<< ", cpu engine: " << &this->state_.cpu_engine; | ||
current_index = states_.size(); | ||
states_.emplace_back(seed); |
Should this be states_.emplace_back(-1, seed)?
Yes, it should be states_.emplace_back(-1, seed)
Fixed in #60310.
seed = new_seed;
}

GeneratorState clone() const {
Why do we need a clone function? Why can't we just write a deep-copy constructor?
Yes, we can use a copy constructor.
Fixed in #60310.
int64_t device;
uint64_t seed;
uint64_t offset;
std::shared_ptr<std::mt19937_64> cpu_engine;
Why do we need std::shared_ptr<std::mt19937_64> instead of std::mt19937_64?
This design choice is highlighted in the original code, which can be reviewed here. The code snippet demonstrates an interesting pattern.
auto seed = GetRandomSeed();
std::seed_seq seq({seed});
auto engine = std::make_shared<std::mt19937_64>(seq);
this->state_.cpu_engine = *engine;
this->engine_ = engine;
An mt19937_64 engine is allocated on the heap, and GeneratorState retains a copy of this initial state. However, this copy is subsequently unused. Instead, Generator holds a shared_ptr pointing to the mt19937_64 engine on the heap, and this pointer is used for engine manipulation, as seen in GetCPUEngine and Random64.

The original code seems to favor the heap-allocated mt19937_64 engine managed by the shared_ptr rather than the engine held by GeneratorState. If GeneratorState were to own the mt19937_64 engine directly, a shared_ptr would not be necessary, since the engine's lifetime would align with the state's lifetime; in that case, a normal pointer would suffice.

I have considered letting GeneratorState directly own the mt19937_64 engine. However, this would entail changing the GetCPUEngine API to a normal pointer, potentially impacting numerous files that use this API.
uint64_t Generator::RegisterStateIndex(const GeneratorState& state) {
  std::lock_guard<std::mutex> lock(mu_);
  auto new_index = states_.size();
  states_.push_back(state);
Why don't we use states_.push_back(state.clone()) here?
Indeed, cloning would be the more appropriate approach in this context.
Fixed in #60310.
return this->state_.current_seed;
std::lock_guard<std::mutex> lock(mu_);
uint64_t seed = GetRandomSeed();
state().reset(seed);
The code logic here is different from the original. The original does not reset this->state_.offset, but the new one resets state().offset.
The reset of the offset in the new implementation is an intentional change. When seed() is invoked, it generates a new random value for the seed and sets the current seed to this value. Consequently, every time a seed is set, it is logical and necessary to reset the offset as well. This approach aligns with SetCurrentSeed in our repository and with the implementation in PyTorch, where the offset is reset whenever a new seed is set.
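As a small illustration (paddle assumed available; most relevant to the GPU generator, whose state is a (seed, offset) pair), re-seeding with the same value is expected to reproduce the original sequence even after earlier draws have advanced the offset, which only works if the offset is reset along with the seed:

import paddle

paddle.seed(100)
first = paddle.rand([3])

_ = paddle.rand([3])   # advances the offset further
paddle.seed(100)       # if the offset were not reset here, the next draw would differ
second = paddle.rand([3])

assert bool((first == second).all())  # identical draws after re-seeding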
orig_rng_state = paddle.get_rng_state()
orig_rng_state_index = paddle.incubate.get_rng_state(use_index=True)
# register a new state and set that state with the seed, store the indices into states_
self.states_[name] = paddle.incubate.register_rng_state_as_index()
paddle.seed(seed)
Are paddle.seed() and self.states_[name] = ... in the wrong order here?
No, the order of operations is correct as implemented. Here's a breakdown of the process:
Original Code:
orig_rng_state = paddle.get_rng_state()
paddle.seed(seed)
self.states_[name] = paddle.get_rng_state()
paddle.set_rng_state(orig_rng_state)
In this snippet, we first save the original RNG state. Then, we set a new seed and update self.states_[name] with the new RNG state. Finally, we restore the original RNG state.
Revised Code:
orig_rng_state_index = paddle.incubate.get_rng_state(use_index=True)
# register a new state and set that state with the seed, store the indices into states_
self.states_[name] = paddle.incubate.register_rng_state_as_index()
paddle.seed(seed)
paddle.incubate.set_rng_state(orig_rng_state_index, use_index=True)
In the revised code, we first obtain the index of the original RNG state. We then register a new RNG state and update self.states_[name] with the index of this new state. Then, we set the seed for the new state. Finally, we switch back to the old state.
PR types
New features
PR changes
APIs
Description
This Pull Request presents a refined approach to managing multiple RNG states within the Generator class. The primary goal is to provide a more robust control mechanism over RNG states to support advanced computational scenarios such as CUDA graphs and tensor parallelism. We introduce an index-based state management system, extending the prior method of direct state manipulation: states are registered and controlled through indices, ensuring graph-safe RNG operations.
Background
Current RNG state management, especially in the context of CUDA graphs, faces challenges with updating the RNG state (a combination of seed and offset). The offset increments with each use, but updating it within CUDA graphs is problematic because it is host state and is not captured in the CUDA graph's operation sequence. This can be demonstrated by the following code, where the offset increment is not recorded in the CUDA graph:
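The original demonstration is not reproduced above; the following is a minimal sketch of the problem, assuming a CUDA build of Paddle and the paddle.device.cuda.graphs.CUDAGraph capture API:

import paddle
from paddle.device.cuda.graphs import CUDAGraph

paddle.seed(2023)
x = paddle.rand([1024], dtype='float32')

graph = CUDAGraph()
graph.capture_begin()
y = paddle.nn.functional.dropout(x, p=0.5)  # consumes RNG numbers, advancing the device-side offset
graph.capture_end()

# Each replay re-runs the captured kernels, but the host-side offset bookkeeping
# that normally happens at kernel launch is not part of the capture, so without
# a correction every replay starts from the same recorded offset.
graph.replay()
graph.replay()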
As proposed in #58310, PaddlePaddle's solution involves wrapping offset updates within callback functions that are invoked before replaying the CUDA graph, manually recalculating and updating the offset for each node. Here's an explanation of how this mechanism is implemented through code:
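The implementation itself is not reproduced here. The following is an illustrative, self-contained sketch of the callback idea only; GeneratorState, pre_replay_callback, and replay_graph below are hypothetical stand-ins, not actual Paddle internals:

from dataclasses import dataclass

@dataclass
class GeneratorState:          # hypothetical stand-in for the host-side RNG state
    seed: int
    offset: int = 0

state = GeneratorState(seed=2023)
ELEMENTS_PER_REPLAY = 1024     # RNG numbers the captured kernels consume per replay

def pre_replay_callback():
    # Recalculate and advance the host-side offset before the graph is replayed,
    # keeping it consistent with what the captured kernels will consume.
    state.offset += ELEMENTS_PER_REPLAY

def replay_graph():
    pre_replay_callback()      # invoked before replaying the captured CUDA graph
    # ... launch the captured CUDA graph here ...

replay_graph()
replay_graph()
print(state.offset)            # 2048: host state reflects two replays' worth of consumption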
Challenges in TP/PP
In Tensor Parallelism (TP) and Pipeline Parallelism (PP), RNG state exchanges are necessary, involving saving and restoring states to Python. This presents difficulties, as the state changes in CUDA graphs cannot be captured with callbacks, as illustrated below:
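The original illustration is not shown above; a minimal sketch of the save/restore pattern (mirroring the RNGStatesTracker code discussed earlier in this thread, paddle assumed available) makes the problem concrete: every get_rng_state/set_rng_state call moves the full state through Python on the host, so none of these transitions can be recorded inside a CUDA graph capture:

import paddle

states = {}

def add_state(name, seed):
    orig = paddle.get_rng_state()
    paddle.seed(seed)
    states[name] = paddle.get_rng_state()   # host-side snapshot of the new state
    paddle.set_rng_state(orig)

def run_with_state(name, fn, *args):
    orig = paddle.get_rng_state()
    paddle.set_rng_state(states[name])      # host-side switch, invisible to a captured graph
    try:
        return fn(*args)
    finally:
        states[name] = paddle.get_rng_state()
        paddle.set_rng_state(orig)

add_state('model-parallel-rng', 1234)
y = run_with_state('model-parallel-rng', paddle.rand, [4])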
PyTorch also faces the same limitations with set_rng_state/get_rng_state in TP contexts under CUDA Graphs.
Solution
To address these issues, we propose new APIs in this PR that allow RNG state management via indices rather than direct state manipulation. This enables capturing the state behavior during CUDA graph operations. Importantly, while we are advancing the system with these new functionalities, the pre-existing set_state and get_state APIs will remain to ensure backward compatibility.

- GetStateIndex / SetStateIndex to switch between RNG states.
- RegisterStateIndex for registering new RNG states and switching to them.
- register_rng_state_as_index for encapsulating and indexing the current RNG state.
- get_rng_state / set_rng_state modified to include a use_index flag for index-based control.
- RNGStatesTracker updated to accommodate the new index-based system, ensuring compatibility with tensor parallelism and CUDA graphs.

With these APIs, we can maintain the state_index of the generator, allowing for graph-safe manipulation and state tracking captured in callback functions:
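For reference, a sketch of the tracker flow rewritten with the index-based APIs from this PR (following the revised RNGStatesTracker snippet shown earlier; a CUDA build of Paddle with the incubate APIs is assumed, and the helper names add_state/run_with_state/states are illustrative). Because switching states is now just an integer index change inside the generator, the switch can be performed in the callbacks that run before CUDA graph replay:

import paddle

states = {}

def add_state(name, seed):
    orig_index = paddle.incubate.get_rng_state(use_index=True)
    # register a fresh state and remember only its integer index
    states[name] = paddle.incubate.register_rng_state_as_index()
    paddle.seed(seed)                                  # seeds the newly registered state
    paddle.incubate.set_rng_state(orig_index, use_index=True)

def run_with_state(name, fn, *args):
    orig_index = paddle.incubate.get_rng_state(use_index=True)
    paddle.incubate.set_rng_state(states[name], use_index=True)
    try:
        return fn(*args)
    finally:
        # the named state lives inside the generator, so no explicit save-back is needed
        paddle.incubate.set_rng_state(orig_index, use_index=True)

add_state('model-parallel-rng', 1234)
y = run_with_state('model-parallel-rng', paddle.rand, [4])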