[thunder] RuntimeError: INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/runtime/fusion_kernel_runtime.cpp":358 #3176

Open
kshitij12345 opened this issue Oct 14, 2024 · 0 comments

Repro string from the error

An error occurred while executing nvFuser FusionDefinition 32.
If you believe this is a bug or need assistance, please file an issue at https://github.com/NVIDIA/Fuser/issues/new
Here's a script to reproduce the error:
```python
# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
#  1: NVIDIA RTX 6000 Ada Generation
# torch version: 2.6.0a0+gita777dea
# cuda version: 12.6
# nvfuser version: 0.2.15+git7616b54
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id32(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 597, 128], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[1, 32, 597, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
    T2 = fd.define_tensor(shape=[1, 32, 597, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T3 = fd.define_tensor(shape=[1, 32, 597, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
    T4 = fd.define_tensor(shape=[1, 597, 128], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.define_tensor(shape=[1, 32, 597, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
    T6 = fd.define_tensor(shape=[1, 32, 597, 128], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    S7 = fd.define_scalar(1, dtype=DataType.Int)
    S8 = fd.define_scalar(1, dtype=DataType.Int)
    S9 = fd.define_scalar(597, dtype=DataType.Int)
    S10 = fd.define_scalar(128, dtype=DataType.Int)
    T12 = fd.ops.broadcast_in_dim(T0, shape=[S7, S8, S9, S10], broadcast_dims=[0, 2, 3])
    S13 = fd.define_scalar(1, dtype=DataType.Int)
    S14 = fd.define_scalar(32, dtype=DataType.Int)
    S15 = fd.define_scalar(597, dtype=DataType.Int)
    S16 = fd.define_scalar(128, dtype=DataType.Int)
    T18 = fd.ops.broadcast_in_dim(T12, shape=[S13, S14, S15, S16], broadcast_dims=[0, 1, 2, 3])
    T19 = fd.ops.cast(T1, dtype=DataType.Float)
    T20 = fd.ops.cast(T2, dtype=DataType.Float)
    T21 = fd.ops.cast(T3, dtype=DataType.Float)
    T22 = fd.ops.cast(T18, dtype=DataType.Float)
    T23 = fd.ops.add(T20, T19)
    T24 = fd.ops.mul(T22, T21)
    T25 = fd.ops.mul(T22, T23)
    T26 = fd.ops.cast(T24, dtype=DataType.BFloat16)
    T27 = fd.ops.cast(T25, dtype=DataType.BFloat16)
    T43 = fd.ops.slice(T26, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 597, 64], strides=[1, 1, 1, 1])
    T59 = fd.ops.slice(T27, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 597, 64], strides=[1, 1, 1, 1])
    T60 = fd.ops.cast(T43, dtype=DataType.Float)
    T61 = fd.ops.cast(T59, dtype=DataType.Float)
    T62 = fd.ops.neg(T60)
    T63 = fd.ops.neg(T61)
    S64 = fd.define_scalar(1, dtype=DataType.Int)
    S65 = fd.define_scalar(1, dtype=DataType.Int)
    S66 = fd.define_scalar(597, dtype=DataType.Int)
    S67 = fd.define_scalar(128, dtype=DataType.Int)
    T69 = fd.ops.broadcast_in_dim(T4, shape=[S64, S65, S66, S67], broadcast_dims=[0, 2, 3])
    T85 = fd.ops.slice(T26, start_indices=[0, 0, 0, 64], end_indices=[1, 32, 597, 128], strides=[1, 1, 1, 1])
    T86 = fd.ops.cast(T62, dtype=DataType.BFloat16)
    T102 = fd.ops.slice(T27, start_indices=[0, 0, 0, 64], end_indices=[1, 32, 597, 128], strides=[1, 1, 1, 1])
    T103 = fd.ops.cast(T63, dtype=DataType.BFloat16)
    S104 = fd.define_scalar(1, dtype=DataType.Int)
    S105 = fd.define_scalar(32, dtype=DataType.Int)
    S106 = fd.define_scalar(597, dtype=DataType.Int)
    S107 = fd.define_scalar(128, dtype=DataType.Int)
    T109 = fd.ops.broadcast_in_dim(T69, shape=[S104, S105, S106, S107], broadcast_dims=[0, 1, 2, 3])
    S110 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T111 = fd.ops.pad(T85, [0, 64, 0, 0, 0, 0, 0, 0], S110)
    S112 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T113 = fd.ops.pad(T86, [64, 0, 0, 0, 0, 0, 0, 0], S112)
    S114 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T115 = fd.ops.pad(T102, [0, 64, 0, 0, 0, 0, 0, 0], S114)
    S116 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T117 = fd.ops.pad(T103, [64, 0, 0, 0, 0, 0, 0, 0], S116)
    T118 = fd.ops.cast(T109, dtype=DataType.Float)
    T119 = fd.ops.cast(T111, dtype=DataType.Float)
    T120 = fd.ops.cast(T113, dtype=DataType.Float)
    T121 = fd.ops.cast(T115, dtype=DataType.Float)
    T122 = fd.ops.cast(T117, dtype=DataType.Float)
    T123 = fd.ops.mul(T118, T21)
    T124 = fd.ops.add(T120, T119)
    T125 = fd.ops.mul(T118, T23)
    T126 = fd.ops.add(T122, T121)
    T127 = fd.ops.cast(T5, dtype=DataType.Float)
    T128 = fd.ops.cast(T6, dtype=DataType.Float)
    T129 = fd.ops.add(T124, T123)
    T130 = fd.ops.add(T126, T125)
    T131 = fd.ops.add(T128, T127)
    T132 = fd.ops.cast(T129, dtype=DataType.BFloat16)
    T133 = fd.ops.cast(T130, dtype=DataType.BFloat16)
    T134 = fd.ops.cast(T131, dtype=DataType.BFloat16)
    T135 = fd.ops.permute(T132, dims=[0, 2, 1, 3])
    T136 = fd.ops.permute(T133, dims=[0, 2, 1, 3])
    T137 = fd.ops.permute(T134, dims=[0, 2, 1, 3])
    T142 = fd.ops.reshape(T135, new_shape=[1, 597, 4096])
    T147 = fd.ops.reshape(T136, new_shape=[1, 597, 4096])
    T152 = fd.ops.reshape(T137, new_shape=[1, 597, 4096])
    T156 = fd.ops.reshape(T142, new_shape=[597, 4096])
    T160 = fd.ops.reshape(T147, new_shape=[597, 4096])
    T164 = fd.ops.reshape(T152, new_shape=[597, 4096])
    T165 = fd.ops.permute(T156, dims=[1, 0])
    T166 = fd.ops.permute(T160, dims=[1, 0])
    T167 = fd.ops.permute(T164, dims=[1, 0])
    fd.add_output(T164)
    fd.add_output(T167)
    fd.add_output(T160)
    fd.add_output(T166)
    fd.add_output(T156)
    fd.add_output(T165)

with FusionDefinition() as fd:
    nvfuser_fusion_id32(fd)

inputs = [
    torch.randn(76416, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 597, 128), (76416, 128, 1)),
    torch.randn(2445312, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 597, 128), (2445312, 128, 4096, 1)),
    torch.randn(2445312, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 597, 128), (2445312, 76416, 128, 1)),
    torch.randn(2445312, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 597, 128), (2445312, 128, 4096, 1)),
    torch.randn(76416, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 597, 128), (76416, 128, 1)),
    torch.randn(2445312, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 597, 128), (2445312, 128, 4096, 1)),
    torch.randn(2445312, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 597, 128), (2445312, 76416, 128, 1)),
]
fd.execute(inputs)
```
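
The error message below suggests rerunning with `NVFUSER_DISABLE=parallel_compile` to get a single, non-interleaved error. A minimal way to do that for the standalone repro (assuming the variable is picked up before the fusion is compiled) is:

```python
# Hedged sketch: disable nvFuser's parallel segment compilation so only the
# failing segment's error is reported. Equivalent to running the repro as
#   NVFUSER_DISABLE=parallel_compile python repro.py
# from the shell; setting the variable before importing nvfuser is the safe option.
import os
os.environ["NVFUSER_DISABLE"] = "parallel_compile"

from nvfuser import FusionDefinition, DataType  # then build and execute the fusion as above
```
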
Stack Trace from the larger program

```
Traceback (most recent call last):
File "/opt/pytorch/nvfuser/nvfuser/init.py", line 181, in execute
results = self._execute(
RuntimeError: INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/runtime/fusion_kernel_runtime.cpp":358, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Detected exception while compiling fusion segments in parallel. Error messages from all threads are printed below.

Error from segmentation group 10: INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/index_compute.cpp":1965, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Couldn't find allocation mapping for T71_l___bfloat[ iblockIdx.x488{( ceilDiv(( ceilDiv(128, 4) ), blockDim.x) )}, iblockIdx.y491{( ceilDiv(( 1 * 597 ), 1) )}, iUS492{1}, iV487{4}, ithreadIdx.x489{blockDim.x} ] ca_pos( 3 ) dim: 2 id: iS301{128}, loops: iblockIdx.x335{( ceilDiv(( ceilDiv(4096, 4) ), blockDim.x) )} iblockIdx.y435{( ceilDiv(( 1 * 597 ), 1) )} iUS436{1} iV487{4} ithreadIdx.x489{blockDim.x}
Exception raised from getNonGlobalConsumerStridedIndices at /opt/pytorch/nvfuser/csrc/index_compute.cpp:1965 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7623932bd099 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x76239366f303 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: + 0x5d9518 (0x762393758518 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: + 0x5d981f (0x76239375881f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: + 0x5d9e70 (0x762393758e70 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: + 0x41df4e (0x76239359cf4e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: + 0x425aeb (0x7623935a4aeb in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: + 0x4273bf (0x7623935a63bf in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: + 0x4243ef (0x7623935a33ef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: + 0x4243ef (0x7623935a33ef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: + 0x4243ef (0x7623935a33ef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: + 0x4273bf (0x7623935a63bf in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: + 0x4243ef (0x7623935a33ef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #13: + 0x4243ef (0x7623935a33ef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #14: + 0x41d8bb (0x76239359c8bb in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #15: + 0x3eb0ae (0x76239356a0ae in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #16: nvfuser::GpuLower::run() + 0x239 (0x762393565fe9 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #17: nvfuser::FusionExecutor::compileFusion(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::SchedulerType, long, long, long, long) + 0xa8a (0x762393938b1a in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #18: + 0x7f1572 (0x762393970572 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #19: + 0x7f183c (0x76239397083c in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #20: c10::ThreadPool::main_loop(unsigned long) + 0x2bd (0x762478d87fbd in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #21: + 0xdc253 (0x7624bf198253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #22: + 0x94ac3 (0x7624bf37eac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #23: + 0x126850 (0x7624bf410850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Use NVFUSER_DISABLE=parallel_compile to simplify error message.
Exception raised from compileFusionParallel at /opt/pytorch/nvfuser/csrc/runtime/fusion_kernel_runtime.cpp:358 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7623932bd099 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x76239366f303 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0xf9d (0x76239397213d in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional) + 0x1cb (0x76239396befb in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional, bool, bool, bool) const + 0x796 (0x762393ad9d76 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: + 0x1ca17e (0x76239334917e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: + 0x24838f (0x7623933c738f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: + 0x2dd850 (0x76239345c850 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)

frame #32: torch::autograd::PyNode::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x95 (0x7624b21b3c75 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #33: + 0x4fb908b (0x7624a992708b in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #34: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0xfd6 (0x7624a99214a6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #35: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x56c (0x7624a99224ec in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #36: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x2ad (0x7624a991aa7d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #37: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x75 (0x7624b21ae685 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #38: + 0xdc253 (0x7624bf198253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #39: + 0x94ac3 (0x7624bf37eac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #40: + 0x126850 (0x7624bf410850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Traceback (most recent call last):
File "/opt/pytorch/lightning-thunder/test.py", line 28, in
out.loss.backward()
File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 624, in backward
torch.autograd.backward(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/init.py", line 347, in backward
_engine_run_backward(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 307, in apply
return user_fn(self, *args)
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 600, in wrapper
outputs = fn(ctx, *args)
File "/opt/pytorch/lightning-thunder/thunder/executors/torch_autograd.py", line 96, in backward
grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
return func(*args, **kwargs)
File "thunder.backward_fn_140", line 451, in backward_fn
File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 456, in call
return fd.execute(args, **kwargs)
File "/opt/pytorch/nvfuser/nvfuser/init.py", line 181, in execute
results = self._execute(
RuntimeError: INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/runtime/fusion_kernel_runtime.cpp":358, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Detected exception while compiling fusion segments in parallel. Error messages from all threads are printed below.

Error from segmentation group 10: INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/index_compute.cpp":1965, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Couldn't find allocation mapping for T71_l___bfloat[ iblockIdx.x488{( ceilDiv(( ceilDiv(128, 4) ), blockDim.x) )}, iblockIdx.y491{( ceilDiv(( 1 * 597 ), 1) )}, iUS492{1}, iV487{4}, ithreadIdx.x489{blockDim.x} ] ca_pos( 3 ) dim: 2 id: iS301{128}, loops: iblockIdx.x335{( ceilDiv(( ceilDiv(4096, 4) ), blockDim.x) )} iblockIdx.y435{( ceilDiv(( 1 * 597 ), 1) )} iUS436{1} iV487{4} ithreadIdx.x489{blockDim.x}
Exception raised from getNonGlobalConsumerStridedIndices at /opt/pytorch/nvfuser/csrc/index_compute.cpp:1965 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7623932bd099 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x76239366f303 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: + 0x5d9518 (0x762393758518 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: + 0x5d981f (0x76239375881f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: + 0x5d9e70 (0x762393758e70 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: + 0x41df4e (0x76239359cf4e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: + 0x425aeb (0x7623935a4aeb in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: + 0x4273bf (0x7623935a63bf in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: + 0x4243ef (0x7623935a33ef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: + 0x4243ef (0x7623935a33ef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: + 0x4243ef (0x7623935a33ef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: + 0x4273bf (0x7623935a63bf in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: + 0x4243ef (0x7623935a33ef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #13: + 0x4243ef (0x7623935a33ef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #14: + 0x41d8bb (0x76239359c8bb in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #15: + 0x3eb0ae (0x76239356a0ae in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #16: nvfuser::GpuLower::run() + 0x239 (0x762393565fe9 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #17: nvfuser::FusionExecutor::compileFusion(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::SchedulerType, long, long, long, long) + 0xa8a (0x762393938b1a in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #18: + 0x7f1572 (0x762393970572 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #19: + 0x7f183c (0x76239397083c in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #20: c10::ThreadPool::main_loop(unsigned long) + 0x2bd (0x762478d87fbd in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #21: + 0xdc253 (0x7624bf198253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #22: + 0x94ac3 (0x7624bf37eac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #23: + 0x126850 (0x7624bf410850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Use NVFUSER_DISABLE=parallel_compile to simplify error message.
Exception raised from compileFusionParallel at /opt/pytorch/nvfuser/csrc/runtime/fusion_kernel_runtime.cpp:358 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7623932bd099 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x76239366f303 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0xf9d (0x76239397213d in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional) + 0x1cb (0x76239396befb in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional, bool, bool, bool) const + 0x796 (0x762393ad9d76 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: + 0x1ca17e (0x76239334917e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: + 0x24838f (0x7623933c738f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: + 0x2dd850 (0x76239345c850 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)

frame #32: torch::autograd::PyNode::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x95 (0x7624b21b3c75 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #33: + 0x4fb908b (0x7624a992708b in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #34: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0xfd6 (0x7624a99214a6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #35: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x56c (0x7624a99224ec in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #36: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x2ad (0x7624a991aa7d in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #37: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x75 (0x7624b21ae685 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #38: + 0xdc253 (0x7624bf198253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #39: + 0x94ac3 (0x7624bf37eac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #40: + 0x126850 (0x7624bf410850 in /usr/lib/x86_64-linux-gnu/libc.so.6)
```

Related thunder issue - Lightning-AI/lightning-thunder#1293 and related comment - Lightning-AI/lightning-thunder#1293 (comment)
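
For context, the trace above comes from a Thunder-jitted model whose loss is backpropagated (test.py line 28, `out.loss.backward()`); the failing fusion is one of Thunder's generated backward fusions handed to the nvFuser executor. The actual test.py is not included in this issue; a rough, hedged sketch of that shape (with a hypothetical stand-in module, not the real model) would be:

```python
# Hedged sketch only -- the real test.py is not shown in this issue.
# A tiny stand-in module is used; the original run uses a bfloat16 transformer.
import torch
import torch.nn as nn
import thunder

model = nn.Linear(128, 128, device="cuda", dtype=torch.bfloat16)  # hypothetical stand-in
jmodel = thunder.jit(model)  # Thunder hands its generated fusions to the nvFuser executor

x = torch.randn(1, 597, 128, device="cuda", dtype=torch.bfloat16)
loss = jmodel(x).float().sum()
loss.backward()  # backward fusions (thunder.backward_fn_*) are compiled and executed here
```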
