[FEA] RAFT should ensure all its symbols are hidden from shared object libraries #1722
Comments
@cjnolet note that there are a few alternatives to explicitly annotating every function with `__attribute__((visibility("hidden")))`. GCC also provides a pragma-based solution that allows modifying the visibility of many symbols at once.
You can also do the inverse of annotating non-binary APIs. There are pros and cons to all of these options, so it'll come down to what makes the most sense for RAFT.
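A minimal sketch of the pragma-based approach, assuming GCC or Clang targeting ELF: every declaration between `#pragma GCC visibility push(hidden)` and `#pragma GCC visibility pop` is given hidden visibility, so a shared object built from it does not export those symbols. The namespace `mylib` and the template `scale` are hypothetical stand-ins for library headers, and the sketch only shows host templates. Whether a particular nvcc version applies the pragma to `__global__` templates is an assumption worth checking with `nm -D` on the built library.

```cpp
// Hypothetical header-only library wrapped in a visibility pragma (GCC/Clang).
// Declarations between push(hidden) and pop get hidden ELF visibility, so a
// .so built from code that instantiates them does not export those symbols.
#pragma GCC visibility push(hidden)

namespace mylib {  // hypothetical stand-in for the library's namespace

// Implicit instantiations such as scale<int> inherit the template's hidden
// visibility, so each shared object keeps its own private copy.
template <typename T>
constexpr T scale(T x, T factor) {
  return x * factor;
}

}  // namespace mylib

#pragma GCC visibility pop
```

Declarations after the `pop` return to the default visibility, so only the wrapped library headers are affected.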
Hi Jake, thanks for the excellent write-up. It's really clear. What are the potential ramifications of "Invoking kernel() results in potentially executing code you did not expect"? What I can think of is:
Is this a concern? The failure case you mentioned, where the PTX version of a kernel could differ within one version of RAFT, is something I already expected could happen. I think our dispatch mechanism should continue to work in this case (because it calls
Yes. Getting different PTX versions is just one example. In general, I think "running code that you did not expect" is not good. The most trivial example would be if there was a bug in the older version of RAFT that was fixed in the newer version, and the library using the newer version inadvertently picks up the kernel with the bug. Another example I've seen cause problems is when one library invokes runtime APIs to configure attributes of a kernel, but the kernel from the other library gets invoked instead and the kernel configuration is ignored.
Related to issues rapidsai#1511 and rapidsai#1490. Should perhaps make this static (can remove weak linkage) or hidden (should keep weak linkage). (See issue rapidsai#1722)
Fixes issue #1511. Make get_cache_idx a weak symbol (to allow linking multiple symbols) without marking it inline (to avoid compilation warnings that are promoted to errors in nvcc 12). Related issues: - #1490 - #1722 Related PRs: - #1732 - #1492 Authors: - Allard Hendriksen (https://github.com/ahendriksen) - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: #1733
This marks all kernels in CUCO as `static` so that they have internal linkage and won't conflict when used by multiple DSOs. I didn't see a single shared/common header in cuco where I could place a `CUCO_KERNEL` macro, so I modified each instance instead. While `cccl` went with a `__attribute__ ((visibility ("hidden")))` approach to help reduce RDC size, this approach seemed very invasive for cuco. This is due to the fact that we would need to pragma push and pop both gcc warnings and nvcc warnings in each cuco header so that we don't introduce any warnings. This is needed as the compiler incorrectly states that the `__attribute__ ((visibility ("hidden")))` has no side-effect. Context: rapidsai/cudf#14726 NVIDIA/cccl#166 rapidsai/raft#1722 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yunsong Wang <[email protected]>
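A minimal sketch of the internal-linkage approach described for cuco, assuming a header-only library: declaring a kernel template `static` gives each translation unit its own private instantiation, so nothing is exported from, or merged across, DSOs. `SKETCH_GLOBAL` and `sketch::insert_first` are hypothetical names, not cuco code; the macro expands to `__global__` only under nvcc so the same sketch also compiles as plain C++.

```cpp
// Expands to __global__ under nvcc and to nothing for a plain host compiler,
// so the sketch can be compiled and exercised without CUDA. (Hypothetical.)
#if defined(__CUDACC__)
#define SKETCH_GLOBAL __global__
#else
#define SKETCH_GLOBAL
#endif

namespace sketch {  // hypothetical namespace

// `static` gives the template internal linkage: every TU that includes this
// header gets its own copy, so two DSOs can never resolve to each other's
// instantiation.
template <typename T>
static SKETCH_GLOBAL void insert_first(T* slots, T key) {
  slots[0] = key;  // trivial body; a real kernel would index by thread
}

}  // namespace sketch
```

Compared with the hidden-visibility route, `static` avoids the warning suppression mentioned above, but without a shared header every kernel declaration has to be edited individually.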
Is your feature request related to a problem? Please describe.
As a user of RAFT, I would like to build a shared object library, `libA.so`, that internally uses RAFT function templates, including `__global__` function templates.

Today, RAFT does nothing to hide the visibility of its `__global__` function templates or any other host template functions, and by default these symbols are emitted as weak symbols with default (exported) visibility. In short, this means that if I link two dynamic libraries, `A.so` and `B.so`, into my application and both contain identical instantiations of a RAFT template, the linker will bind every reference to just one of the two instantiations and ignore the other. This can lead to disastrous and insidious issues like spurious silent failures.

This problem applies to any header-only C++ template library, but it is particularly bad for CUDA C++ libraries that ship `__global__` function templates. Consider this trivial example, built from two translation units (TUs), of one of many ways things can go wrong.
Each TU has a single function (`volta()` or `pascal()`, respectively), and this function queries and prints the `ptxVersion` of a `kernel<void>` using `cudaFuncGetAttributes`. These TUs are linked into a program that determines the compute capability of device 0 and invokes `volta()` or `pascal()` accordingly.

One would expect that invoking `volta()` would always print 70 and invoking `pascal()` would print 60. However, this is not the case. As described above, the kernel template has weak linkage, so when the `volta.o` and `pascal.o` TUs are linked together, the linker selects one of the instantiations of `kernel<void>` and discards the other. The end result is that the program prints either 60 or 70 depending on which instantiation the linker picked, not on which GPU it is running on.
TL;DR:
Describe the solution you'd like
Luckily, the solution is quite simple. Every `__global__` function should be annotated as `static __attribute__((visibility("hidden")))`. This makes the symbol hidden in any resulting dynamic library.
For convenience, you'd likely want to wrap these annotations in a macro.
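A minimal sketch of such a macro, assuming GCC/Clang attribute syntax and a hypothetical name `MYLIB_KERNEL` (RAFT's actual macro may be named and defined differently). It pairs `__global__` with hidden visibility, similar to the approach `cccl` adopted; outside nvcc it falls back to hidden visibility alone so the sketch can be compiled as plain C++.

```cpp
// Hypothetical convenience macro: kernels are __global__ with hidden ELF
// visibility, so they are never exported from (or interposed across) DSOs.
#if defined(__CUDACC__)
#define MYLIB_KERNEL __global__ __attribute__((visibility("hidden")))
#else
// Host-only fallback so the sketch compiles without nvcc.
#define MYLIB_KERNEL __attribute__((visibility("hidden")))
#endif

namespace mylib {  // hypothetical namespace

// Annotating every kernel template through one macro keeps the policy in a
// single place: switching to `static` later means editing only the #define.
template <typename T>
MYLIB_KERNEL void fill_first(T* out, T value) {
  out[0] = value;  // trivial body; a real kernel would index by thread
}

}  // namespace mylib
```

Centralizing the annotation also makes it easy to audit: any kernel declared without `MYLIB_KERNEL` stands out in review.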
Additional Context
We've been bitten by this in Thrust/CUB several times over the years.
Thrust/CUB also have the ability to allow users to customize the namespace in order to differentiate the symbols and avoid this problem. However, this solution is not robust. First of all, it requires every user to remember to customize the namespace. Secondly, it's possible for users to properly customize the namespace and still run afoul of the issues that can result.
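To illustrate the namespace-customization idea, here is a minimal sketch with a hypothetical `MYLIB_NAMESPACE` macro (Thrust/CUB's real configuration macros are named differently). A library compiled with a distinct value gets differently mangled symbols, which is exactly why this only helps if every consumer remembers to set it.

```cpp
// Hypothetical override point: a consumer such as libA.so could compile with
// -DMYLIB_NAMESPACE=mylib_liba so its instantiations mangle to unique names.
#ifndef MYLIB_NAMESPACE
#define MYLIB_NAMESPACE mylib
#endif

namespace MYLIB_NAMESPACE {

template <typename T>
constexpr T twice(T x) {
  return x + x;
}

}  // namespace MYLIB_NAMESPACE
```

If `libA.so` and `libB.so` both keep the default `mylib`, a runtime use of `mylib::twice<int>` is again a single weak, exported symbol shared between them, and the failure mode described in this issue returns.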
See: