Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace RMM CUDA Python bindings with those provided by CUDA-Python #451

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
1 change: 1 addition & 0 deletions conda/environments/raft_dev_cuda11.5.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
dependencies:
- cudatoolkit=11.5
- cuda-python >=11.5,<12.0
- clang=11.1.0
- clang-tools=11.1.0
- rapids-build-env=22.02.*
Expand Down
24 changes: 5 additions & 19 deletions python/raft/common/cuda.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,23 +14,9 @@
# limitations under the License.
#

# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3
from cuda.ccudart cimport cudaStream_t

cdef class Stream:
cdef cudaStream_t s

# Populate this with more typedef's (eg: events) as and when needed
cdef extern from * nogil:
ctypedef void* _Stream "cudaStream_t"
ctypedef int _Error "cudaError_t"


# Populate this with more runtime api method declarations as and when needed
cdef extern from "cuda_runtime_api.h" nogil:
_Error cudaStreamCreate(_Stream* s)
_Error cudaStreamDestroy(_Stream s)
_Error cudaStreamSynchronize(_Stream s)
_Error cudaGetLastError()
const char* cudaGetErrorString(_Error e)
const char* cudaGetErrorName(_Error e)
cdef cudaStream_t getStream(self)
47 changes: 23 additions & 24 deletions python/raft/common/cuda.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -19,10 +19,22 @@
# cython: embedsignature = True
# cython: language_level = 3

from cuda.ccudart cimport(
cudaStream_t,
cudaError_t,
cudaSuccess,
cudaStreamCreate,
cudaStreamDestroy,
cudaStreamSynchronize,
cudaGetLastError,
cudaGetErrorString,
cudaGetErrorName
)


class CudaRuntimeError(RuntimeError):
def __init__(self, extraMsg=None):
cdef _Error e = cudaGetLastError()
cdef cudaError_t e = cudaGetLastError()
cdef bytes errMsg = cudaGetErrorString(e)
cdef bytes errName = cudaGetErrorName(e)
msg = "Error! %s reason='%s'" % (errName.decode(), errMsg.decode())
Expand All @@ -45,29 +57,17 @@ cdef class Stream:
stream.sync()
del stream # optional!
"""

# NOTE:
# If we store _Stream directly, this always leads to the following error:
# "Cannot convert Python object to '_Stream'"
# I was unable to find a good solution to this in reasonable time. Also,
# since cudaStream_t is a pointer anyways, storing it as an integer should
# be just fine (although, that certainly is ugly and hacky!).
cdef size_t s

def __cinit__(self):
if self.s != 0:
return
shwina marked this conversation as resolved.
Show resolved Hide resolved
cdef _Stream stream
cdef _Error e = cudaStreamCreate(&stream)
if e != 0:
cdef cudaStream_t stream
cdef cudaError_t e = cudaStreamCreate(&stream)
if e != cudaSuccess:
raise CudaRuntimeError("Stream create")
self.s = <size_t>stream
self.s = stream

def __dealloc__(self):
self.sync()
cdef _Stream stream = <_Stream>self.s
cdef _Error e = cudaStreamDestroy(stream)
if e != 0:
cdef cudaError_t e = cudaStreamDestroy(self.s)
if e != cudaSuccess:
raise CudaRuntimeError("Stream destroy")

def sync(self):
Expand All @@ -76,10 +76,9 @@ cdef class Stream:
could raise exception due to issues with previous asynchronous
launches
"""
cdef _Stream stream = <_Stream>self.s
cdef _Error e = cudaStreamSynchronize(stream)
if e != 0:
cdef cudaError_t e = cudaStreamSynchronize(self.s)
if e != cudaSuccess:
raise CudaRuntimeError("Stream sync")

def getStream(self):
cdef cudaStream_t getStream(self):
return self.s
3 changes: 1 addition & 2 deletions python/raft/common/handle.pxd
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -21,7 +21,6 @@


from libcpp.memory cimport shared_ptr
from .cuda cimport _Stream
from rmm._lib.cuda_stream_view cimport cuda_stream_view
from rmm._lib.cuda_stream_pool cimport cuda_stream_pool
from libcpp.memory cimport shared_ptr
Expand Down
9 changes: 5 additions & 4 deletions python/raft/common/handle.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,9 +24,10 @@ from libcpp.memory cimport shared_ptr
from rmm._lib.cuda_stream_view cimport cuda_stream_per_thread
from rmm._lib.cuda_stream_view cimport cuda_stream_view

from .cuda cimport _Stream, _Error, cudaStreamSynchronize
from .cuda cimport Stream
from .cuda import CudaRuntimeError


cdef class Handle:
"""
Handle is a lightweight python wrapper around the corresponding C++ class
Expand All @@ -51,7 +52,7 @@ cdef class Handle:
del handle # optional!
"""

def __cinit__(self, stream=None, n_streams=0):
def __cinit__(self, stream: Stream = None, n_streams=0):
self.n_streams = n_streams
if n_streams > 0:
self.stream_pool.reset(new cuda_stream_pool(n_streams))
Expand All @@ -64,7 +65,7 @@ cdef class Handle:
self.stream_pool))
else:
# this constructor constructs a handle on user stream
c_stream = cuda_stream_view(<_Stream><size_t> stream.getStream())
c_stream = cuda_stream_view(stream.getStream())
self.c_obj.reset(new handle_t(c_stream,
self.stream_pool))

Expand Down