Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUDA] Allow dynamic shmem of size > 48K in runtime #11478

Merged
merged 4 commits into from
May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions python/tvm/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import sys
import os
import subprocess
import logging

from .._ffi.base import py_str

Expand Down Expand Up @@ -239,7 +238,6 @@ def _linux_compile(output, objects, options, compile_cmd, compile_shared=False):
cmd += objects
if options:
cmd += options
logging.info("invoking '%s'", cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
Expand All @@ -266,7 +264,6 @@ def _windows_compile(output, objects, options):
cmd += options

try:
logging.info("invoking '%s'", cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
except FileNotFoundError:
Expand Down
2 changes: 0 additions & 2 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import subprocess
import os
import warnings
import logging

import tvm._ffi
from tvm.target import Target
Expand Down Expand Up @@ -103,7 +102,6 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target
# if cxx_compiler_path != "":
# cmd += ["-ccbin", cxx_compiler_path]

logging.info("invoking '%s'", cmd)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

(out, _) = proc.communicate()
Expand Down
13 changes: 12 additions & 1 deletion src/runtime/cuda/cuda_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,22 @@ class CUDAWrappedFunc {
void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const {
int device_id;
CUDA_CALL(cudaGetDevice(&device_id));
ThreadWorkLoad wl = launch_param_config_.Extract(args);

if (fcache_[device_id] == nullptr) {
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
if (wl.dyn_shmem_size >= (48 << 10)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if dynamic memory is too large, will it pass VerifyGPUCode check?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haven't tested but yeah, it seems VerifyGPUCode checks the static alloc size against max_shared_memory_per_block, which would fail if dyn_shmem_size >= (48 << 10)

} else if (storage_scope.rank == runtime::StorageRank::kShared) {
size_t size = static_cast<size_t>(op->ConstantAllocationSize());

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we defer this issue later? I need this to demonstrate that a multi-stage pipeline with depth > 2 works on a semi-realistic cuda schedule.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah let's defer this particular issue

// Assumption: dyn_shmem_size doesn't change across different invocations of
// fcache_[device_id]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumption could be controversial, but this should be mostly ok in practice. To support a kernel which uses different big shmem sizes depending on input, we need to call cuFuncSetAttribute on every invocation.

CUresult result = cuFuncSetAttribute(
fcache_[device_id], CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, wl.dyn_shmem_size);
if (result != CUDA_SUCCESS) {
LOG(FATAL) << "Failed to set the allowed dynamic shared memory size to "
<< wl.dyn_shmem_size;
}
}
}
CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
ThreadWorkLoad wl = launch_param_config_.Extract(args);
CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_cascade(SRAM, FLASH, TwoConv2DWithSliceTE, TwoConv2DTE, MobileNetv1Star
cs.cascade(sch, te_graph, const_dict, options, SRAM, FLASH, [SRAM], device_config)


@pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/11483")
def test_compute_cycles_annotation(SRAM, FLASH, TwoConv2DTE):
device_config = cs.EthosuDeviceConfig("ethos-u55-256")
options = infra.make_options(
Expand Down