Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC-0023 Unified Memory for Pytorch #36

Open
wants to merge 16 commits into
base: master
Choose a base branch
from

Conversation

jayfurmanek
Copy link

This RFC proposes to add Unified Virtual Memory (UVM) (or “Managed Memory”) function utilizing the managed memory allocation APIs available in CUDA/ROCm.

The proposed changes to the front end and back end have been minimized as much as possible to have a very targeted effect when UVM is enabled and have no effect at all when UVM is disabled, which will of course remain the default.

Please note that the details of these proposals are subject to revision given feedback from users and prototype testing. Please feel free to comment on the RFCs with your feedback

@facebook-github-bot
Copy link
Contributor

Hi @jayfurmanek!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@jayfurmanek
Copy link
Author

I signed it!

@facebook-github-bot
Copy link
Contributor

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@albanD
Copy link
Contributor

albanD commented Jan 18, 2022

@ngimel
Copy link

ngimel commented Jan 18, 2022

cc @mcarilli

1. UVM is enabled
2. The `.to()` only changes the device and does not change any other tensor attribute.

`Torch.cuda.uvm_to(Tensor t)`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure this is needed. Do you have specific use cases where the user would not want this behavior but get the original to() behavior?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We agree with your statement, and would like to propose using the 'copy' parameter in the 'to' function for a user to get the original behavior of the operation.
Such that tensor.to(device='cpu', copy=true) or the dtype or layout changes, we would force the tensor to allocate more storage and copy to the target device.

Otherwise, if the tensor.to() only changed the device type and copy=false(default), the operation would call a move_ function that uses the same memory, but changes DeviceType to the new device.

A use case for this would be if you wanted a tensor to be accessible on different devices, but allowed to be changed asychronously.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To illustrate, the following would always copy (assume the tensor is on cpu)

tensor.to(device='cpu', copy=true)
tensor.to(device='cuda:0', copy=false, dtype=torch.float32)

And this would call the move (zero copy)

tensor.to(device='cuda:0', copy=false)


## FAQ

#### Should we have set_enabled_uvm(bool)? Will there be a use case where the user can set this to false?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a technical reason for not providing this? Or just that we don't expect users to need this feature?

I could see a use case where a user does:

enable uvm
create the model and profile memory usage
relocate/optimize memory usage
disable uvm
run the training without uvm

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the result of our "big switch" approach of enabling UVM mode. Since memory must be allocated using a different API, having it always on, or off made the design simpler.

Perhaps being able to mark a tensor as "managed" would help your use-case above. We could add a new flag (managed=true) to the to() function that would be able to cast a tensor to a managed tensor. We could still have the big switches (and a big switch disable), but those would just dictate the default value of the managed parameter here.

This would allow more user control on which tensors are managed. So:

recast (reallocate) a non-managed tensor to a managed tensor

tensor.to(device="cuda:0", managed=true)

The reverse would also work.

- No explicit cudaMemcpy in either direction
- Set cudaMemPrefetchAsync()
- Device to Host synchronization
- There is no concept of a GPU “stream” on the host so Device-to-Host copies have a required synchronization step attached. With UVM enabled, no explicit copy will be scheduled, but synchronization will still be required.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify here what is expected from the end user? Does this mean that doing a to("cpu") and using the result straight away is unsafe? Or the allocator will make sure the data is available before we can read that memory?

If it is the first case, I think we should force a sync explicitly here unless the user passes non_blocking=True meaning that it takes on the responsibility of doing this sync.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct.
Honoring the non_blocking=true also makes sense. The "move kernel" could check that and skip the device sync if directed to, passing on the responsibility.


We propose to treat tensor.to(device) as a special case when UVM is enabled and `enabled_uvm_move_on_copy` has been set. Additionally, the behavior of tensor.uvm_to() will match the behavior of tensor.to() when UVM is enabled and `enabled_uvm_move_on_copy` is True.

![Existing and proposed allocator usage](./RFC-0023-assets/copy-diagram.jpg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what it means here for the cpu Tensor not to exist anymore? the copy op is not modifying the source inplace so I am not sure how this behavior would fit within pytorch.

Does this mean that there is no way to have b = a.to("cuda") (with a on cpu initially) and have:

  • a will keep behaving as a cpu Tensor
  • b will behave as a gpu Tensor
  • a and b share memory (in the sense that inplace ops on either of them is reflected on the other).

Effectively, this means that this .to() will only be a cross-device view op (meaning that the input and output share memory). The parallel to this in core is the torch.view_as_real() function for example that is a view op for which the input is complex and the output is float dtype.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In line with the statement I made up above, we are going to add the use of the copy=true parameter on the .to() function to allow an explicit copy to be made even when the tensor is managed by uvm.

For the torch.view_as_real() we are treating a .to() as an explicit change of DeviceType for the tensor. Even though the memory address may be the same, .to() will trigger a prefetch to tell the driver to start moving the memory to the target device. Necessary synchronizations and other actions will be taken as needed to ensure correct coherency.

Copy link

@dzhulgakov dzhulgakov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One interesting thought for core composibility is to see whether this RFC can be prototyped outside of the main repo for faster iteration. The extension point for allocator should be already sufficient. Overriding kernels might be a bit more tricky because the copy kernel has a lot of craft in it that might be replicated. Maybe it's not worth it though.


- In `at::empty_cpu()`, there is already a switch to determine which cpu allocator to use (pinned vs cpu). When UVM is enabled, this switch will always return the CachingManagedAllocator. There are likely a few other places (other empties) that get an allocator as well.

- Per a find from the FBGEMM team, there is an issue with the context being moved from the device where the memory was allocated, and where it currently resides. We need to keep the original context, so we don’t create new ones unnecessarily on other devices. To account for this, at `at::empty` time, we will create two Storage structs:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clarification: 'context' here really refers to DataPtr of 'storage'. Since the Storage is bound to a device, when moving to a different device we need to create a new Storage instance and point it to the old one

Currently there are different allocators used for CPU (CPUAllocator) , pinned host memory (CachingHostAllocator), and device memory (CUDACachingAllocator) memory allocations.

- When UVM is enabled, we propose forgoing all traditional allocators and working with a newly developed CachingManagedAllocator.
- This will provide benefits in a few areas. We can continue to develop and improve on Unified memory allocations without the risk of introducing performance degradation, or errors to the existing CUDACachingAllocator. In addition, we expect the behavior of Managed memory to be different than that of traditional discrete GPU/CPU memory. For example, managed memory is available to all GPUs on all streams by default and can move between host and any number of GPUs at will. This will provide intricacies to the caching methodology that may conflict with the existing caching allocators.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

about the stream part - the reason caching allocator has to keep track of streams for deallocation is to make deallocation synchronization free and be able to reuse the same buffer on another stream (after injecting a pair of events). I'd imagine you need to do the same here too

but in general agree on having separate allocator and being able to customize behavior

@dllehr-amd
Copy link

@dzhulgakov Do you have any examples of using an allocator extension point? I haven't found much documentation on it, and I may be looking in the wrong areas. If you have any docs etc. that provide more clarification on this I'd love to look at it. Thanks!!

@dzhulgakov
Copy link

dzhulgakov commented Jan 31, 2022

Do you have any examples of using an allocator extension point?

For CPU allocator there's some proper wiring in the core library allowing to override the allocator used for new tensors: https://github.com/pytorch/pytorch/blob/master/c10/core/Allocator.h#L210 and it works as the allocations from empty_cpu just go to GetAllocator (https://github.com/pytorch/pytorch/blob/master/c10/core/CPUAllocator.cpp#L138).

Unfortunately, I originally forgot that CUDA allocator is not wired through this interface. Some use cases do use it via Allocator interface. But the majority of places call CUDACachingAllocator directly. CUDACachingAllocator is not even a namespace, not a class atm: https://github.com/pytorch/pytorch/blob/72c972e1e1b4ad838de604e35269e200a70db5f2/c10/cuda/CUDACachingAllocator.h#L32 (I remember we wanted to fix it, but sadly never did).

That means that for immediate prototyping of this RFC we'd need to do something like modifying the code in-place (i.e. in a temporary fork).

The more extensible course of action would be to turn CUDACachingAllocator into a proper class and allow to override it with a different implementation like for the CPU one. The interface of it would need to be broader than the base Allocator interface though as there's the need for recordStream. It's a worthy refactoring, would you be interested in taking it on?

cc @ezyang

@ngimel
Copy link

ngimel commented Jan 31, 2022

pytorch/pytorch#65365 adds an alternative backend for cuda allocations (and puts backends in the namespaces, THC and CUDAMallocAsync). Design is not finalized (and THC name definitely has to go), but probably we can decide what we need to do in that PR, as it will need to go in the core soon-ish, probably sooner than unified memory.

@mcarilli
Copy link

mcarilli commented Jan 31, 2022

The more extensible course of action would be to turn CUDACachingAllocator into a proper class

In pytorch/pytorch#65365, to avoid a vtable lookup and maintain existing inlining opportunities, I deliberately chose a non-polymorphic design. Each (inline) interface function in CUDACachingAllocator.h installs the correct static function pointer which it uses from then on, for example:

inline void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
  static auto f = (std::strcmp(allocatorBackend(), "native") == 0) ?
    THC::raw_alloc_with_stream : CudaMallocAsync::raw_alloc_with_stream;
  return f(nbytes, stream);
}

There are only two backends now, but macros could allow extensibility without too much boilerplate.
I did it this way because I thought it was neat and accommodated existing implementations cleanly without big refactors. I don't know if avoiding the vtable lookup REALLY makes a performance difference. Probably not, but then again, this code is called often.

That being said, I'm not sure we need a full-blown unique allocator implementation to accommodate managed memory. My first instinct is most (or all!) of what a cached managed memory allocator would need is identical to what the native allocator already does for cudaMalloc...so maybe we could just s/cudaMalloc/cudaMallocManaged/g in CUDACachingAllocator.cpp (or add an envvar-based conditional choice between them at all sites where cudaMalloc is currently called) and blow imaginary smoke off our finger guns.

We'd need to add a few calls to expose managed memory's prefetch-to-gpu and prefetch-to-cpu functionality, but my point is I think most of the caching code could be shared with the native allocator.

@jayfurmanek
Copy link
Author

Ah, Thanks for the insight, @mcarilli, we were just debating if a proper class for the allocator makes sense and hadn't considered the possible inline benefits. That's interesting.

.so maybe we could just s/cudaMalloc/cudaMallocManaged/g in CUDACachingAllocator.cpp

That was our first instinct as well! There are a few complications, or at least decisions, that have to be resolved. One is the CUDACachingAllocator has a DeviceAllocator for each present CUDA device, and the caching is relevant to each. With managed memory, it's abstracted, so that allocator hierarchy makes less sense. Prefetch and device hints are used for data locality suggestions and those need to happen somewhere. Another question is do we want to allow some designated tensors to be managed, or just have a big switch to do all or none.

Our design proposes a single allocator for all devices (CPU too). Initially the caching will look like the existing CUDACachingAllocator, but there would be room for extension if needed (for CPU focused data loading for example).

@ezyang
Copy link
Contributor

ezyang commented Feb 10, 2022

What I'd like to see are some examples / case studies of when you would use UVM, as opposed to traditional management. What is it good for, what is a typical user program that is using UVM going to look like? How would you have written this code in straight line C? If you want to be ROCm specific that's fine.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants