Rename createCUDAStream() to getStreamFromPool() (pytorch#12940)
Summary:
Pull Request resolved: pytorch#12940

Dmytro was reading this code and requested that we rename the interface
to something that made it more obvious that pooling was going on.
Seems reasonable to me! Final name is a suggestion from Pieter.

Reviewed By: dzhulgakov

Differential Revision: D10492071

fbshipit-source-id: b1c2cac760f666968d58166be649dabfe1127c5e
ezyang authored and facebook-github-bot committed Oct 24, 2018
1 parent 924326e commit ca03c10
Showing 11 changed files with 37 additions and 26 deletions.
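At a glance, the rename at a typical call site (taken directly from the diffs below):

  // before
  auto stream = at::cuda::createCUDAStream();
  // after
  auto stream = at::cuda::getStreamFromPool();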
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CUDAContext.cpp
@@ -29,10 +29,10 @@ cudaDeviceProp* getDeviceProperties(int64_t device) {
}

/* Streams */
CUDAStream createCUDAStream(
CUDAStream getStreamFromPool(
const bool isHighPriority
, int64_t device) {
return detail::CUDAStream_createStream(isHighPriority, device);
return detail::CUDAStream_getStreamFromPool(isHighPriority, device);
}

CUDAStream getDefaultCUDAStream(int64_t device) {
13 changes: 12 additions & 1 deletion aten/src/ATen/cuda/CUDAContext.h
@@ -47,8 +47,19 @@ CAFFE2_API cudaDeviceProp* getCurrentDeviceProperties();
CAFFE2_API cudaDeviceProp* getDeviceProperties(int64_t device);

/* Streams */

/**
* Get a new stream from the CUDA stream pool. You can think of this
* as "creating" a new stream, but no such creation actually happens;
* instead, streams are preallocated from the pool and returned in a
* round-robin fashion.
*
* You can request a stream from the high priority pool by setting
* isHighPriority to true, or a stream for a specific device by setting device
* (defaulting to the current device).
*/
CAFFE2_API CUDAStream
createCUDAStream(const bool isHighPriority = false, int64_t device = -1);
getStreamFromPool(const bool isHighPriority = false, int64_t device = -1);

CAFFE2_API CUDAStream getDefaultCUDAStream(int64_t device = -1);
CAFFE2_API CUDAStream getCurrentCUDAStream(int64_t device = -1);
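For context, a minimal usage sketch of the renamed interface as declared above (assuming an ATen/CUDA build; the helper name workOnPooledStream and the elided kernel work are illustrative, not part of this change):

  #include <ATen/cuda/CUDAContext.h>

  void workOnPooledStream() {
    // Request a stream from the pool for the current device; nothing is
    // created here, a preallocated stream is handed out round-robin.
    at::cuda::CUDAStream s = at::cuda::getStreamFromPool(/*isHighPriority=*/false);

    // Make it the current stream so subsequent CUDA work is enqueued on it.
    at::cuda::setCurrentCUDAStream(s);

    // The raw cudaStream_t is available for direct CUDA runtime calls.
    cudaStream_t raw = s.stream();
    (void)raw;
  }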
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDAStream.cpp
@@ -162,7 +162,7 @@ static uint32_t get_idx(std::atomic<uint32_t> &counter) {
// Returns a stream from the requested pool
// Note: when called the first time on a device, this will create the
// stream pools for that device.
CUDAStreamInternals* CUDAStream_createStream(
CUDAStreamInternals* CUDAStream_getStreamFromPool(
const bool isHighPriority
, int64_t device) {
initCUDAStreamsOnce();
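The note above says the pools are created lazily the first time a stream is requested on a device. A rough sketch of once-only initialization, with hypothetical names (pool_init_flag, initStreamPoolsOnce) rather than the actual ATen internals:

  #include <mutex>

  static std::once_flag pool_init_flag;

  static void initStreamPoolsOnce() {
    // std::call_once guarantees the pools are built exactly once, even if
    // several threads request their first stream concurrently.
    std::call_once(pool_init_flag, [] {
      // ... allocate the low- and high-priority stream pools here ...
    });
  }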
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDAStream.h
@@ -61,7 +61,7 @@ namespace detail {
// Pointer-based API (for internal use)
AT_CUDA_API CUDAStreamInternals* CUDAStream_getDefaultStream(int64_t device = -1);

AT_CUDA_API CUDAStreamInternals* CUDAStream_createStream(
AT_CUDA_API CUDAStreamInternals* CUDAStream_getStreamFromPool(
const bool isHighPriority = false
, int64_t device = -1);

30 changes: 15 additions & 15 deletions aten/src/ATen/test/stream_test.cpp
@@ -33,7 +33,7 @@ TEST(TestStream, CopyAndMoveTest) {
// Tests that copying works as expected and preserves the stream
at::cuda::CUDAStream copyStream;
{
auto s = at::cuda::createCUDAStream();
auto s = at::cuda::getStreamFromPool();
device = s.device();
cuda_stream = s.stream();

@@ -51,7 +51,7 @@ TEST(TestStream, CopyAndMoveTest) {
// Tests that moving works as expected and preserves the stream
at::cuda::CUDAStream moveStream;
{
auto s = at::cuda::createCUDAStream();
auto s = at::cuda::getStreamFromPool();
device = s.device();
cuda_stream = s.stream();

@@ -68,7 +68,7 @@ TEST(TestStream, CopyAndMoveTest) {

// Verifies streams are set properly
TEST(TestStream, GetAndSetTest) {
at::cuda::CUDAStream myStream = at::cuda::createCUDAStream();
at::cuda::CUDAStream myStream = at::cuda::getStreamFromPool();

// Sets and gets
at::cuda::setCurrentCUDAStream(myStream);
@@ -86,7 +86,7 @@ TEST(TestStream, GetAndSetTest) {
}

void thread_fun(at::cuda::CUDAStream& cur_thread_stream) {
auto new_stream = at::cuda::createCUDAStream();
auto new_stream = at::cuda::getStreamFromPool();
at::cuda::setCurrentCUDAStream(new_stream);
cur_thread_stream = at::cuda::getCurrentCUDAStream();
ASSERT_EQ_CUDA(cur_thread_stream, new_stream);
@@ -120,7 +120,7 @@ TEST(TestStream, CUDAGuardTest) {

ASSERT_EQ_CUDA(at::cuda::current_device(), 0);
std::vector<at::cuda::CUDAStream> streams0 = {
at::cuda::getDefaultCUDAStream(), at::cuda::createCUDAStream()};
at::cuda::getDefaultCUDAStream(), at::cuda::getStreamFromPool()};
ASSERT_EQ_CUDA(streams0[0].device(), 0);
ASSERT_EQ_CUDA(streams0[1].device(), 0);
at::cuda::setCurrentCUDAStream(streams0[0]);
@@ -129,7 +129,7 @@ TEST(TestStream, CUDAGuardTest) {
{
at::DeviceGuard device_guard(1);
streams1.push_back(at::cuda::getDefaultCUDAStream());
streams1.push_back(at::cuda::createCUDAStream());
streams1.push_back(at::cuda::getStreamFromPool());
}
ASSERT_EQ_CUDA(streams1[0].device(), 1);
ASSERT_EQ_CUDA(streams1[1].device(), 1);
@@ -190,7 +190,7 @@ TEST(TestStream, CUDAGuardMovableTest) {
if (at::cuda::getNumGPUs() < 2) {
return;
}
const auto stream = at::cuda::createCUDAStream();
const auto stream = at::cuda::getStreamFromPool();
const auto device_count = at::cuda::getNumGPUs();
at::cuda::CUDAGuard first(stream);
first.set_device(1);
@@ -209,7 +209,7 @@ TEST(TestStream, StreamPoolTest) {
TEST(TestStream, StreamPoolTest) {
std::vector<at::cuda::CUDAStream> streams{};
for (int i = 0; i < 200; ++i) {
streams.emplace_back(at::cuda::detail::CUDAStream_createStream());
streams.emplace_back(at::cuda::detail::CUDAStream_getStreamFromPool());
}

std::unordered_set<cudaStream_t> stream_set{};
@@ -229,8 +229,8 @@ TEST(TestStream, MultiGPUTest) {
if (at::cuda::getNumGPUs() < 2)
return;

at::cuda::CUDAStream s0 = at::cuda::createCUDAStream(true, 0);
at::cuda::CUDAStream s1 = at::cuda::createCUDAStream(false, 1);
at::cuda::CUDAStream s0 = at::cuda::getStreamFromPool(true, 0);
at::cuda::CUDAStream s1 = at::cuda::getStreamFromPool(false, 1);

at::cuda::setCurrentCUDAStream(s0);
at::cuda::setCurrentCUDAStream(s1);
@@ -243,15 +243,15 @@

// CUDAEvent Syncs
TEST(TestStream, CUDAEventSyncTest) {
const auto stream = at::cuda::createCUDAStream();
const auto stream = at::cuda::getStreamFromPool();
at::cuda::CUDAEvent event;

ASSERT_FALSE(event.happened());

event.recordOnce(stream);

const auto wait_stream0 = at::cuda::createCUDAStream();
const auto wait_stream1 = at::cuda::createCUDAStream();
const auto wait_stream0 = at::cuda::getStreamFromPool();
const auto wait_stream1 = at::cuda::getStreamFromPool();

wait_stream0.synchronize_with(event);
wait_stream1.synchronize_with(event);
@@ -265,11 +265,11 @@ TEST(TestStream, CrossDeviceTest) {
if (at::cuda::getNumGPUs() < 2)
return;

const auto stream0 = at::cuda::createCUDAStream();
const auto stream0 = at::cuda::getStreamFromPool();
at::cuda::CUDAEvent event0;

at::cuda::set_device(1);
const auto stream1 = at::cuda::createCUDAStream();
const auto stream1 = at::cuda::getStreamFromPool();
at::cuda::CUDAEvent event1;

event0.record(stream0);
2 changes: 1 addition & 1 deletion aten/src/THC/THCStream.cpp
@@ -6,7 +6,7 @@ THC_API THCStream* THCStream_defaultStream(int device) {
}

THC_API THCStream* THCStream_new() {
return at::cuda::detail::CUDAStream_createStream();
return at::cuda::detail::CUDAStream_getStreamFromPool();
}

THC_API cudaStream_t THCStream_stream(THCStream* stream) {
2 changes: 1 addition & 1 deletion torch/csrc/cuda/Stream.cpp
@@ -36,7 +36,7 @@ static PyObject * THCPStream_pynew(PyTypeObject *type, PyObject *args, PyObject
stream = (THCStream*) cdata;
} else {
const bool isHighPriority = priority < 0 ? true : false;
stream = at::cuda::detail::CUDAStream_createStream(isHighPriority);
stream = at::cuda::detail::CUDAStream_getStreamFromPool(isHighPriority);
}

THCPStream* self = (THCPStream *)ptr.get();
2 changes: 1 addition & 1 deletion torch/lib/c10d/ProcessGroupGloo.cpp
@@ -530,7 +530,7 @@ EntryType ProcessGroupGloo::construct(const AlgorithmKey& key) {
entry->events.resize(key.devices.size());
for (size_t i = 0; i < key.devices.size(); i++) {
deviceGuard.set_index(key.devices[i]);
entry->streams[i] = at::cuda::createCUDAStream();
entry->streams[i] = at::cuda::getStreamFromPool();
entry->events[i] = CUDAEvent::create();
}
}
2 changes: 1 addition & 1 deletion torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -260,7 +260,7 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);

// Also create the NCCL streams and events
streamVal[i] = at::cuda::createCUDAStream();
streamVal[i] = at::cuda::getStreamFromPool();
// Event created using cudaEventDisableTiming flag and not
// cudaEventBlockingSync flag will provide the best performance when used
// with cudaStreamWaitEvent() and cudaEventQuery(). Since we here don't
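The comment above motivates the event flag choice; a standalone sketch of creating such an event with the CUDA runtime API (the helper name makeSyncOnlyEvent is illustrative, not the c10d code itself):

  #include <cuda_runtime.h>

  cudaEvent_t makeSyncOnlyEvent() {
    cudaEvent_t event;
    // cudaEventDisableTiming: the event is used only for ordering via
    // cudaStreamWaitEvent()/cudaEventQuery(), so timing bookkeeping is skipped.
    cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
    return event;
  }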
2 changes: 1 addition & 1 deletion torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp
@@ -89,7 +89,7 @@ class AsyncInputIsOutputTest : public AsyncTest {
streams_.resize(numDevices_);
for (auto i = 0; i < numDevices_; i++) {
deviceGuard.set_index(i);
streams_[i] = at::cuda::createCUDAStream();
streams_[i] = at::cuda::getStreamFromPool();
}
}

2 changes: 1 addition & 1 deletion torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
@@ -70,7 +70,7 @@ class NCCLTest : public NCCLTestBase {
streams_.resize(numDevices_);
for (auto i = 0; i < numDevices_; i++) {
deviceGuard.set_index(i);
streams_[i] = at::cuda::createCUDAStream();
streams_[i] = at::cuda::getStreamFromPool();
}
}
