
Enhanced RNG State Management with Index-Based Control for Graph-Safe Tensor Parallelism #58859

Conversation

@eee4017 (Contributor) commented Nov 9, 2023

PR types

New features

PR changes

APIs

Description

This pull request presents a refined approach to managing multiple RNG states within the Generator class. The primary goal is a more robust control mechanism over RNG states to support advanced computational scenarios such as CUDA graphs and tensor parallelism. We introduce an index-based state management system that extends the prior method of direct state manipulation: states are registered once and then controlled through indices, ensuring graph-safe RNG operations.

Background

Current RNG state management, especially in the context of CUDA graphs, faces challenges when updating the RNG state (a combination of seed and offset). The offset increments with each use, but updating it inside a CUDA graph is problematic because it is host state and is not captured in the graph's operation sequence. The following code demonstrates this: the offset increments are not recorded in the CUDA graph.

int offset = 0;

// cudaGraph begin
Kernel0<<<...>>>(param0, param1, param2, …);
offset = offset + 1; // host-side update: NOT recorded in the graph
Kernel1<<<...>>>(param0, offset, param2, …);
Kernel2<<<...>>>(param0, param1, param2, …);
offset = offset + 1; // host-side update: NOT recorded in the graph
Kernel3<<<...>>>(offset, param1, param2, …);
// cudaGraph end

As proposed in #58310, PaddlePaddle's solution wraps offset updates in callback functions that are invoked before replaying the CUDA graph, manually recalculating and updating the offset for each node. Here is how this mechanism works in code:

int offset;
int *offset_ptr = &offset;

// This callback increments the RNG offset and returns the updated value.
auto callback = [offset_ptr]() {
    *offset_ptr = *offset_ptr + 1;
    return *offset_ptr;
};

int rng_state = callback();
cuGraphKernelNodeSetParams(Kernel1_node, rng_state);
rng_state = callback();
cuGraphKernelNodeSetParams(Kernel3_node, rng_state);
cudaGraph.replay(); // replays Kernel0, Kernel1, Kernel2, Kernel3

Challenges in TP/PP

In tensor parallelism (TP) and pipeline parallelism (PP), RNG state exchanges are necessary: states are saved to and restored from Python. This is problematic because the state changes inside a CUDA graph cannot be captured with callbacks, as illustrated below:

int offset;

// cudaGraph begin

// RNGStatesTracker.__enter__
self.states_[name] = paddle.get_rng_state() // save the current RNG state into a Python dict
paddle.set_rng_state(...)                   // set a new random state; offset is set to a new value

offset = offset + 1;
Kernel1<<<...>>>(param0, offset, param2, …);

// RNGStatesTracker.__exit__
paddle.set_rng_state(...) // restore the old random state

// cudaGraph end

PyTorch faces the same limitation with set_rng_state/get_rng_state in TP contexts under CUDA graphs.

Solution

To address these issues, this PR proposes new APIs that manage RNG states via indices rather than direct state manipulation, making state switches capturable during CUDA graph operations. The pre-existing set_state and get_state APIs are retained to ensure backward compatibility.

  • C++ Generator API:
    • GetStateIndex/SetStateIndex to switch between RNG states.
    • RegisterStateIndex for registering new RNG states and switching to them.
    • Refactoring of the generator's constructor and methods for better readability.
    • Dropout layer updates for graph-safe tensor parallelism using the correct RNG state index.
  • Python API:
    • register_rng_state_as_index for encapsulating and indexing the current RNG state.
    • get_rng_state/set_rng_state modified to accept a use_index flag for index-based control.
    • Modification of the RNGStatesTracker to accommodate the new index-based system, ensuring compatibility with tensor parallelism and CUDA graphs.

With these APIs, we can maintain the generator's state_index, allowing graph-safe manipulation and state tracking inside callback functions:

uint64_t state_index = gen->GetStateIndex();
auto callback = [gen, state_index]() {
    gen->SetStateIndex(state_index);
    // call the generator here to increment the offset
};
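
For illustration, here is a minimal sketch of how such a callback can plug into the pre-replay node-update pattern from the earlier example. This is a sketch under assumptions: gen, Kernel1_node, and IncrementOffset are illustrative names, and cuGraphKernelNodeSetParams is simplified exactly as in the pseudocode above, not the real CUDA signature.

uint64_t state_index = gen->GetStateIndex();  // the index, unlike the state, is stable across replays

auto callback = [gen, state_index]() {
    gen->SetStateIndex(state_index);  // reselect the registered RNG state
    return gen->IncrementOffset(1);   // assumed helper: advance the offset on the host
};

// Before each replay, refresh the kernel's RNG argument via the callback:
uint64_t rng_state = callback();
cuGraphKernelNodeSetParams(Kernel1_node, rng_state);
cudaGraph.replay();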

@jeng1220 added the NVIDIA label Nov 9, 2023
@paddle-bot added the contributor (External developers) label Nov 9, 2023

x = generate_random_number_and_states(gen)

assert_allclose = lambda x, y: np.testing.assert_allclose(
A reviewer (Contributor) commented:

Why not array_equal?

@eee4017 (Contributor, Author) replied Nov 13, 2023:

This follows test_cuda_random_seed; I believe we should use allclose in this context, as it involves floats.

@zhiqiu (Contributor) replied Nov 14, 2023:

I checked the first case, test_gen_dropout_dygraph, in test_cuda_random_seed; it also passes if allclose is changed to equal.

@eee4017 (Contributor, Author) replied:

Changed to assert_array_equal.

@eee4017 force-pushed the enhanced_rng_state_management_with_index_based_control_for_graph_safe_tensor_parallelism branch from 7fced59 to a5c3d11 on November 15, 2023.
@eee4017 (Contributor, Author) commented Nov 16, 2023:

Hello, I need approval for this change from one of the following individuals:

RD approval for API change: @XiaoguangHu01 , @jeff41404 , @lanxianghit or @qingqing01
TPM approval for API change and API documents change: @jzhang533 (ZhangJun), @sunzhongkai588 (SunZhongKai), @Ligoml (LiMengLiu)

@jzhang533 (Contributor) commented:

@eee4017 LGTM on the API naming and semantics. I noticed that you have proposed a similar feature to the PyTorch community: pytorch/pytorch#113541.

Perhaps we should wait for feedback from the PyTorch community. If the proposed API turns out to be compatible, it could greatly benefit both communities.

If this PR is blocking our ongoing project, please let me know; as a temporary measure, we can merge this PR now and adjust the API naming and semantics later, before our next release.

@eee4017 (Contributor, Author) commented Nov 16, 2023:

Thank you for reviewing the proposed API change. If the change is acceptable, I recommend merging this PR now; that minimizes extended back-and-forth and time spent re-running CI tests. Note that the RNG implementation in PyTorch differs from Paddle's, so the PyTorch issue may require its own thorough review and discussion, on an uncertain timeframe. I will keep you updated with any progress or new insights from the PyTorch community. In the meantime, moving forward with the current PR is a practical way to maintain momentum in our project. Thank you.

@jzhang533 (Contributor) previously approved these changes Nov 16, 2023, leaving a comment:

LGTM

@@ -811,6 +812,7 @@
'unique_consecutive',
'set_cuda_rng_state',
'set_rng_state',
'register_rng_state_as_index',
A reviewer (Contributor) commented:

Should register_rng_state_as_index be a public API?

@eee4017 (Contributor, Author) replied:

Moved these APIs to paddle.incubate.

@eee4017 force-pushed the enhanced_rng_state_management_with_index_based_control_for_graph_safe_tensor_parallelism branch 3 times, most recently from 5058691 to 2772e3d, on November 28, 2023.
@eee4017 force-pushed the same branch from 2772e3d to c0eb546 on November 30, 2023.
@risemeup1 previously approved these changes Dec 2, 2023.
@eee4017 force-pushed the branch from c0eb546 to 9f0789b on December 4, 2023, and requested a review from @risemeup1 on December 4, 2023.
@zyfncg merged commit 3bcdeef into PaddlePaddle:develop on Dec 5, 2023.
SigureMo pushed a commit to gouzil/Paddle that referenced this pull request Dec 5, 2023
… Tensor Parallelism (PaddlePaddle#58859)

* allow multiple rng state in generator

* fix get_rng_state
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
current_index = states_.size();
states_.emplace_back(seed);
A reviewer (Collaborator) commented:

Should this be states_.emplace_back(-1, seed)?

@eee4017 (Contributor, Author) replied:

Yes, it should be states_.emplace_back(-1, seed).

@eee4017 (Contributor, Author) added:

Fixed in #60310.
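
For context on why the argument list matters, here is a minimal C++ sketch. It rests on an assumption: a constructor of the form GeneratorState(device, seed), consistent with the fields quoted further down in this review.

#include <cstdint>
#include <vector>

// Assumed shape of the state; field names follow the GeneratorState quoted below.
struct GeneratorState {
  int64_t device;
  uint64_t seed;
  uint64_t offset;
  GeneratorState(int64_t device_arg = -1, uint64_t seed_arg = 0)
      : device(device_arg), seed(seed_arg), offset(0) {}
};

void Demo(uint64_t seed) {
  std::vector<GeneratorState> states_;
  states_.emplace_back(seed);      // bug: seed binds to the device parameter
  states_.emplace_back(-1, seed);  // intended: placeholder device id, then seed
}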

seed = new_seed;
}

GeneratorState clone() const {
A reviewer (Collaborator) commented:

Why do we need a clone function? Why can't we just write a deep-copy constructor?

@eee4017 (Contributor, Author) replied:

Yes, we can use a copy constructor.

@eee4017 (Contributor, Author) added:

Fixed in #60310.
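
As a sketch of the suggested alternative, a deep-copy constructor can do what clone() does. This assumes the fields quoted below in this review; the engine is copied so the two states advance independently afterward.

#include <cstdint>
#include <memory>
#include <random>

// Sketch: a deep-copy constructor replacing clone().
struct GeneratorState {
  int64_t device = -1;
  uint64_t seed = 0;
  uint64_t offset = 0;
  std::shared_ptr<std::mt19937_64> cpu_engine;

  GeneratorState() : cpu_engine(std::make_shared<std::mt19937_64>()) {}
  GeneratorState(const GeneratorState& other)
      : device(other.device),
        seed(other.seed),
        offset(other.offset),
        // deep-copy the engine so copies do not share random streams
        cpu_engine(std::make_shared<std::mt19937_64>(*other.cpu_engine)) {}
};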

int64_t device;
uint64_t seed;
uint64_t offset;
std::shared_ptr<std::mt19937_64> cpu_engine;
A reviewer (Collaborator) commented:

Why do we need std::shared_ptr<std::mt19937_64> instead of std::mt19937_64?

@eee4017 (Contributor, Author) replied Dec 21, 2023:

This design choice comes from the original code, which can be reviewed here. The snippet demonstrates an interesting pattern:

auto seed = GetRandomSeed();
std::seed_seq seq({seed});
auto engine = std::make_shared<std::mt19937_64>(seq);
this->state_.cpu_engine = *engine;
this->engine_ = engine;

An mt19937_64 engine is allocated on the heap, and GeneratorState retains a copy of its initial state. However, this copy is subsequently unused. Instead, Generator holds a shared_ptr to the heap-allocated mt19937_64 engine, and that pointer is used for all engine manipulation, as seen in GetCPUEngine and Random64.

The original code favors the heap-allocated mt19937_64 engine managed by the shared_ptr over the engine held by GeneratorState. If GeneratorState owned the mt19937_64 engine directly, a shared_ptr would not be necessary, since the engine's lifetime would match the state's lifetime; a plain pointer would suffice.

I have considered letting GeneratorState own the mt19937_64 engine directly. However, this would require changing the GetCPUEngine API to return a plain pointer, potentially impacting the many files that use this API.
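
To illustrate the trade-off, here is a sketch; the names are illustrative, not Paddle's actual classes.

#include <memory>
#include <random>

// Option 1 (current design): the state shares ownership of a heap-allocated
// engine, so GetCPUEngine can hand out a shared_ptr that callers keep alive.
struct StateWithSharedEngine {
  std::shared_ptr<std::mt19937_64> cpu_engine =
      std::make_shared<std::mt19937_64>();
  std::shared_ptr<std::mt19937_64> GetCPUEngine() const { return cpu_engine; }
};

// Option 2 (considered alternative): the state owns the engine by value, so
// GetCPUEngine must return a raw pointer or reference instead, which changes
// the API for every existing caller.
struct StateWithOwnedEngine {
  std::mt19937_64 cpu_engine;
  std::mt19937_64* GetCPUEngine() { return &cpu_engine; }
};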

uint64_t Generator::RegisterStateIndex(const GeneratorState& state) {
std::lock_guard<std::mutex> lock(mu_);
auto new_index = states_.size();
states_.push_back(state);
A reviewer (Collaborator) commented:

Why don't we use states_.push_back(state.clone()) here?

@eee4017 (Contributor, Author) replied:

Indeed, cloning would be the more appropriate approach in this context.

@eee4017 (Contributor, Author) added:

Fixed in #60310.

return this->state_.current_seed;
std::lock_guard<std::mutex> lock(mu_);
uint64_t seed = GetRandomSeed();
state().reset(seed);
A reviewer (Collaborator) commented:

The code logic here differs from the original: the original does not reset this->state_.offset, but the new code resets state().offset.

@eee4017 (Contributor, Author) replied Dec 21, 2023:

The reset of offset in the new implementation is an intentional change. When seed() is invoked, a new random seed is generated and installed as the current seed; whenever a new seed is set, it is logical and necessary to reset the offset as well. This aligns with SetCurrentSeed in our repository and with the implementation in PyTorch here, where the offset is reset whenever a new seed is set.
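
A minimal sketch of the intended semantics (field names follow the GeneratorState quoted earlier; the reset helper is illustrative):

#include <cstdint>

// Sketch: installing a new seed must also reset the offset; otherwise an
// offset accumulated under the old seed would leak into the new random stream.
struct StateView {
  uint64_t seed = 0;
  uint64_t offset = 0;
  void reset(uint64_t new_seed) {
    seed = new_seed;
    offset = 0;  // intentional, matching SetCurrentSeed and PyTorch behavior
  }
};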

orig_rng_state = paddle.get_rng_state()
orig_rng_state_index = paddle.incubate.get_rng_state(use_index=True)
# register a new state and set that state with the seed, store the indices into states_
self.states_[name] = paddle.incubate.register_rng_state_as_index()
paddle.seed(seed)
A reviewer (Collaborator) commented:

Are paddle.seed() and self.states_[name] = ... reordered incorrectly here?

@eee4017 (Contributor, Author) replied:

No, the order of operations is correct as implemented. Here's a breakdown of the process:

Original Code:

orig_rng_state = paddle.get_rng_state()
paddle.seed(seed)
self.states_[name] = paddle.get_rng_state()
paddle.set_rng_state(orig_rng_state)

In this snippet, we first save the original RNG state. Then, we set a new seed and update self.states_[name] with the new RNG state. Finally, we restore the original RNG state.

Revised Code:

orig_rng_state_index = paddle.incubate.get_rng_state(use_index=True)
# register a new state and set that state with the seed, store the indices into states_
self.states_[name] = paddle.incubate.register_rng_state_as_index()
paddle.seed(seed)
paddle.incubate.set_rng_state(orig_rng_state_index, use_index=True)

In the revised code, we first obtain the index of the original RNG state. We then register a new RNG state, which also switches the generator to it, and store its index in self.states_[name]. paddle.seed(seed) therefore seeds the newly registered state. Finally, we switch back to the old state by index.

XiaoguangHu01 pushed a commit that referenced this pull request Dec 21, 2023
…aph-Safe Tensor Parallelism (#58859)" (#60148)

This reverts commit 3bcdeef.
zyfncg added a commit that referenced this pull request Dec 22, 2023
…aph-Safe Tensor Parallelism (#58859)" (#60147)

* Revert "Enhanced RNG State Management with Index-Based Control for Graph-Safe Tensor Parallelism (#58859)"

This reverts commit 3bcdeef.

* fix complie bug

* fix complie
@eee4017 mentioned this pull request Dec 25, 2023.