Merging IterDomains requires that their iteration types match. #2317

Closed
wujingyue opened this issue May 29, 2024 · 7 comments · Fixed by #2326

@wujingyue
Collaborator

wujingyue commented May 29, 2024

Check out wjy/linear and run NVFUSER_DISABLE=parallel_compile python repro.py.

Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 139, in execute
    result = self._execute(
RuntimeError: Merging IterDomains requires that their iteration types match. Outer: iS284{( ceilDiv(1600, 32) )}, Inner: rS257{i6}
Exception raised from merge at /opt/pytorch/nvfuser/csrc/ir/nodes.cpp:2565 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7f7d3928490d in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::IterDomain::merge(nvfuser::IterDomain*, nvfuser::IterDomain*, bool) + 0x38a (0x7f7d396e94ea in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x58364d (0x7f7d396e964d in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: nvfuser::TensorView::merge(long, long) + 0x106 (0x7f7d3989c796 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::scheduleTranspose(nvfuser::Fusion*, nvfuser::TransposeParams) + 0xe9f (0x7f7d39866e2f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x703551 (0x7f7d39869551 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x5c4166 (0x7f7d3972a166 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0x447 (0x7f7d39731bc7 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0xad3 (0x7f7d3973d7e3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, bool, bool, std::optional<signed char>) const + 0x3c8 (0x7f7d39920ab8 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x18bae5 (0x7f7d392f1ae5 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x1ffb02 (0x7f7d39365b02 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x2873f0 (0x7f7d393ed3f0 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
<omitting python frames>
frame #28: <unknown function> + 0x29d90 (0x7f7e48d77d90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #29: __libc_start_main + 0x80 (0x7f7e48d77e40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

This happened after I rebased https://github.com/Lightning-AI/lightning-thunder/tree/wjy/sharded for #2199. I suspect 3D linear isn't handled as well as reshape+2D_linear+reshape.
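For reference, the two formulations mentioned above compute the same values; only the IR that nvFuser has to schedule differs. The following is a minimal pure-Python sketch of that equivalence, using nested lists instead of tensors. The helper names (`linear_2d`, `linear_3d`, `linear_3d_via_reshape`) are made up for illustration and are not part of nvFuser or Thunder.

```python
def linear_2d(x, w, b):
    """y[m][n] = sum_k x[m][k] * w[n][k] + b[n] (weight stored as [N, K])."""
    return [
        [sum(xk * wk for xk, wk in zip(row, w_row)) + b_n for w_row, b_n in zip(w, b)]
        for row in x
    ]


def linear_3d(x, w, b):
    """Apply the same linear map to every [M, K] slice of a [B, M, K] input."""
    return [linear_2d(batch, w, b) for batch in x]


def linear_3d_via_reshape(x, w, b):
    """Flatten [B, M, K] to [B*M, K], run the 2D linear, then restore [B, M, N]."""
    num_rows = len(x[0])
    flat = [row for batch in x for row in batch]
    out = linear_2d(flat, w, b)
    return [out[i : i + num_rows] for i in range(0, len(out), num_rows)]


# Small integer inputs so the comparison is exact.
B, M, K, N = 2, 3, 4, 5
x = [[[b * 100 + m * 10 + k for k in range(K)] for m in range(M)] for b in range(B)]
w = [[n - k for k in range(K)] for n in range(N)]
bias = list(range(N))

direct = linear_3d(x, w, bias)
reshaped = linear_3d_via_reshape(x, w, bias)
print(direct == reshaped)
```

Since the results agree, any difference in behavior between the two paths comes from how the fusion is segmented and scheduled, not from the math.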

@wujingyue
Collaborator Author

FWIW, https://github.com/Lightning-AI/lightning-thunder/blob/bc3925be04e7ec58d9b24fb7ac55fbe007862a65/thunder/tests/opinfos.py#L6095-L6117 could be enhanced to test 3D. Currently, it only tests 1D and 2D input shapes.

@Priya2698
Collaborator

> FWIW, https://github.com/Lightning-AI/lightning-thunder/blob/bc3925be04e7ec58d9b24fb7ac55fbe007862a65/thunder/tests/opinfos.py#L6095-L6117 could be enhanced to test 3D. Currently, it only tests 1D and 2D input shapes.

Agreed, we do test for 3D cases in our tests though:

def linear_input_generator(
    op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs
):
    make_arg = partial(
        make_tensor,
        dtype=dtype,
        device="cuda",
        low=None,
        high=None,
        requires_grad=requires_grad,
    )
    B = 64
    M = 512
    N = 256
    K = 32
    # Cases without bias
    shapes_input = ((K), (M, K), (B, M, K), (B, 1, M, K))
    shapes_weight = ((K), (N, K), (1, K))
    for shape_input, shape_weight in itertools.product(shapes_input, shapes_weight):
        yield SampleInput(make_arg(shape_input), make_arg(shape_weight))
    # Cases with bias
    shape_weight = (N, K)
    shapes_bias = ((), (N,))
    for shape_input, shape_bias in itertools.product(shapes_input, shapes_bias):
        yield SampleInput(
            make_arg(shape_input), make_arg(shape_weight), make_arg(shape_bias)
        )

We likely need more tests such as the reproducer of this issue and other larger fusions.
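As a rough aid for extending such generators, here is a shape-only sketch that enumerates the same input/weight combinations and computes the output shape each one should produce. It assumes the usual `linear` convention (input `[..., K]`, weight `[N, K]` or `[K]`); `linear_output_shape` is a hypothetical helper, not something from the nvFuser or Thunder test suites, and the 1D shapes are written as `(K,)` tuples for uniformity.

```python
import itertools


def linear_output_shape(input_shape, weight_shape):
    """Shape of linear(input, weight): input is [..., K]; weight is [N, K] or [K]."""
    if input_shape[-1] != weight_shape[-1]:
        raise ValueError(f"K mismatch: {input_shape} vs {weight_shape}")
    batch_dims = tuple(input_shape[:-1])
    if len(weight_shape) == 1:
        # A 1D weight contracts K away and adds no output-feature axis.
        return batch_dims
    return batch_dims + (weight_shape[0],)


B, M, N, K = 64, 512, 256, 32
shapes_input = ((K,), (M, K), (B, M, K), (B, 1, M, K))
shapes_weight = ((K,), (N, K), (1, K))

cases = {
    (shape_input, shape_weight): linear_output_shape(shape_input, shape_weight)
    for shape_input, shape_weight in itertools.product(shapes_input, shapes_weight)
}

for (shape_input, shape_weight), shape_out in cases.items():
    print(shape_input, shape_weight, "->", shape_out)
```

A table like this only checks shapes; it would not have caught this issue on its own, since the failure depends on how a larger fusion is segmented.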

@Priya2698
Collaborator

Priya2698 commented May 31, 2024

The issue seems related to #1659.

The segment handed to the transpose scheduler:

g{(transpose)
inputs: 
T22_g[ iS61{16}, iS62{128}, iS63{1600} ] float
T60_g[ iS175{16}, iS176{128}, iS253{1600}, rS254{i6} ] __bfloat
outputs: 
T68_g[ iS200{16}, iS201{128}, iS202{1600} ] float


T61_l[ iS179{16}, iS180{128}, iS181{1600} ]
   = rng_uniform_range({16, 128, 1600}, double(0), double(1), __bfloat);
(58)
T62_g[ iS182{16}, iS183{128}, iS184{1600} ]
   = __bfloat2float(T61_l[ iS179{16}, iS180{128}, iS181{1600} ]);
(59)
T63_g[ iS185{16}, iS186{128}, iS187{1600} ]
   = T62_g[ iS182{16}, iS183{128}, iS184{1600} ]
   < double(0.90000000000000002);
(60)
T65_g[ iS191{16}, iS192{128}, iS193{1600} ]
   = (float)(T63_g[ iS185{16}, iS186{128}, iS187{1600} ]);
(62)
T64_g[ iS188{16}, iS189{128}, iS255{1600} ]
   = __bfloat2float(T60_g[ iS175{16}, iS176{128}, iS253{1600}, rS254{i6} ]);
(61)
T66_g[ iS194{16}, iS195{128}, iS196{1600} ]
   = T64_g[ iS188{16}, iS189{128}, iS255{1600} ]
   * T65_g[ iS191{16}, iS192{128}, iS193{1600} ];
(63)
T67_g[ iS197{16}, iS198{128}, iS199{1600} ]
   = T66_g[ iS194{16}, iS195{128}, iS196{1600} ]
   * double(1.11111);
(64)
T68_g[ iS200{16}, iS201{128}, iS202{1600} ]
   = T22_g[ iS61{16}, iS62{128}, iS63{1600} ]
   + T67_g[ iS197{16}, iS198{128}, iS199{1600} ];
(65)
}
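The detail to notice in the dump above is the `rS254{i6}` entry on input `T60_g`. A small sketch for spotting such inputs mechanically, assuming the printed form `<type letter>S<id>{<extent>}` seen in this dump (with `i` for iteration and `r` for reduction); the regex and the helper `reduction_domains` are illustrative only and not part of nvFuser:

```python
import re

# Matches IterDomain tokens such as "iS175{16}" or "rS254{i6}":
# one type letter, "S", a numeric id, and an extent in braces.
ITER_DOMAIN = re.compile(r"\b([a-z])S(\d+)\{([^}]*)\}")


def reduction_domains(tensor_line):
    """Return the reduction IterDomain tokens printed on one tensor line."""
    return [m.group(0) for m in ITER_DOMAIN.finditer(tensor_line) if m.group(1) == "r"]


segment_inputs = [
    "T22_g[ iS61{16}, iS62{128}, iS63{1600} ] float",
    "T60_g[ iS175{16}, iS176{128}, iS253{1600}, rS254{i6} ] __bfloat",
]

offenders = {line.split("[")[0]: reduction_domains(line) for line in segment_inputs}
print(offenders)
```

Only `T60_g` carries a reduction domain, which is consistent with the `Inner: rS257{i6}` operand in the error message.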

@wujingyue
Collaborator Author

FYI @Priya2698, https://github.com/Lightning-AI/lightning-thunder/tree/wjy/bug2317 is the Thunder branch that reproduces the bug.

$ pytest thunder/benchmarks/targets.py -k test_nanogpt_block_fwd[thunder] -s

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ Captured log call ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
ERROR    nvfuser:__init__.py:205 An error occurred while executing nvFuser FusionDefinition 1.
If you believe this is a bug or need assistance, please file an issue at https://github.com/NVIDIA/Fuser/issues/new
Here's a script to reproduce the error:
```python
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T5 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T6 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T7 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T8 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T9 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T10 = fd.ops.permute(T9, dims=[0, 2, 1, 3])
    T11 = fd.ops.stride_order(T10, stride_order=[3, 2, 1, 0])
    S12 = fd.define_scalar(16, dtype=DataType.Int)
    S13 = fd.define_scalar(128, dtype=DataType.Int)
    S14 = fd.define_scalar(1600, dtype=DataType.Int)
    V15 = fd.define_vector([S12, S13, S14], dtype=DataType.Int)
    T16 = fd.ops.reshape(T11, new_shape=V15)
    T17 = fd.ops.linear(T16, T1, T0)
    S18 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S19 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S20 = fd.define_scalar(16, dtype=DataType.Int)
    S21 = fd.define_scalar(128, dtype=DataType.Int)
    S22 = fd.define_scalar(1600, dtype=DataType.Int)
    V23 = fd.define_vector([S20, S21, S22], dtype=DataType.Int)
    T24 = fd.ops.uniform(S18, S19, shape=V23, dtype=DataType.BFloat16)
    S25 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T26 = fd.ops.lt(T24, S25)
    T27 = fd.ops.cast(T17, dtype=DataType.Float)
    T28 = fd.ops.cast(T26, dtype=DataType.Float)
    T29 = fd.ops.mul(T27, T28)
    S30 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T31 = fd.ops.mul(T29, S30)
    T32 = fd.ops.cast(T8, dtype=DataType.Float)
    T33 = fd.ops.add(T32, T31)
    T34, T35 = fd.ops.var_mean(T33, dims=[2], correction=0, keepdim=False)
    S36 = fd.define_scalar(16, dtype=DataType.Int)
    S37 = fd.define_scalar(128, dtype=DataType.Int)
    S38 = fd.define_scalar(1, dtype=DataType.Int)
    V39 = fd.define_vector([S36, S37, S38], dtype=DataType.Int)
    T40 = fd.ops.broadcast_in_dim(T34, shape=V39, broadcast_dims=[0, 1])
    S41 = fd.define_scalar(16, dtype=DataType.Int)
    S42 = fd.define_scalar(128, dtype=DataType.Int)
    S43 = fd.define_scalar(1, dtype=DataType.Int)
    V44 = fd.define_vector([S41, S42, S43], dtype=DataType.Int)
    T45 = fd.ops.broadcast_in_dim(T35, shape=V44, broadcast_dims=[0, 1])
    S46 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T47 = fd.ops.add(T40, S46)
    T48 = fd.ops.rsqrt(T47)
    S49 = fd.define_scalar(16, dtype=DataType.Int)
    S50 = fd.define_scalar(128, dtype=DataType.Int)
    S51 = fd.define_scalar(1600, dtype=DataType.Int)
    V52 = fd.define_vector([S49, S50, S51], dtype=DataType.Int)
    T53 = fd.ops.broadcast_in_dim(T45, shape=V52, broadcast_dims=[0, 1, 2])
    T54 = fd.ops.sub(T33, T53)
    S55 = fd.define_scalar(16, dtype=DataType.Int)
    S56 = fd.define_scalar(128, dtype=DataType.Int)
    S57 = fd.define_scalar(1600, dtype=DataType.Int)
    V58 = fd.define_vector([S55, S56, S57], dtype=DataType.Int)
    T59 = fd.ops.broadcast_in_dim(T48, shape=V58, broadcast_dims=[0, 1, 2])
    T60 = fd.ops.mul(T54, T59)
    S61 = fd.define_scalar(16, dtype=DataType.Int)
    S62 = fd.define_scalar(128, dtype=DataType.Int)
    S63 = fd.define_scalar(1600, dtype=DataType.Int)
    V64 = fd.define_vector([S61, S62, S63], dtype=DataType.Int)
    T65 = fd.ops.broadcast_in_dim(T3, shape=V64, broadcast_dims=[2])
    T66 = fd.ops.cast(T65, dtype=DataType.Float)
    T67 = fd.ops.mul(T60, T66)
    S68 = fd.define_scalar(16, dtype=DataType.Int)
    S69 = fd.define_scalar(128, dtype=DataType.Int)
    S70 = fd.define_scalar(1600, dtype=DataType.Int)
    V71 = fd.define_vector([S68, S69, S70], dtype=DataType.Int)
    T72 = fd.ops.broadcast_in_dim(T2, shape=V71, broadcast_dims=[2])
    T73 = fd.ops.cast(T72, dtype=DataType.Float)
    T74 = fd.ops.add(T67, T73)
    T75 = fd.ops.cast(T74, dtype=DataType.BFloat16)
    T76 = fd.ops.linear(T75, T5, T4)
    T77 = fd.ops.cast(T76, dtype=DataType.Float)
    T78 = fd.ops.mul(T77, T77)
    T79 = fd.ops.mul(T78, T77)
    S80 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T81 = fd.ops.mul(S80, T77)
    S82 = fd.define_scalar(0.0447150, dtype=DataType.Double)
    T83 = fd.ops.mul(S82, T79)
    T84 = fd.ops.add(T77, T83)
    S85 = fd.define_scalar(0.797885, dtype=DataType.Double)
    T86 = fd.ops.mul(S85, T84)
    T87 = fd.ops.tanh(T86)
    S88 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T89 = fd.ops.add(S88, T87)
    T90 = fd.ops.mul(T81, T89)
    T91 = fd.ops.cast(T90, dtype=DataType.BFloat16)
    T92 = fd.ops.linear(T91, T7, T6)
    S93 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S94 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S95 = fd.define_scalar(16, dtype=DataType.Int)
    S96 = fd.define_scalar(128, dtype=DataType.Int)
    S97 = fd.define_scalar(1600, dtype=DataType.Int)
    V98 = fd.define_vector([S95, S96, S97], dtype=DataType.Int)
    T99 = fd.ops.uniform(S93, S94, shape=V98, dtype=DataType.BFloat16)
    S100 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T101 = fd.ops.lt(T99, S100)
    T102 = fd.ops.cast(T92, dtype=DataType.Float)
    T103 = fd.ops.cast(T101, dtype=DataType.Float)
    T104 = fd.ops.mul(T102, T103)
    S105 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T106 = fd.ops.mul(T104, S105)
    T107 = fd.ops.add(T33, T106)
    T108 = fd.ops.cast(T107, dtype=DataType.BFloat16)
    fd.add_output(T108)

with FusionDefinition() as fd:
    nvfuser_fusion_id1(fd)

inputs = [
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((2560000,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600, 1600), (1600, 1)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((6400,), dtype=torch.bfloat16, device='cuda:0').as_strided((6400,), (1,)),
    torch.randn((10240000,), dtype=torch.bfloat16, device='cuda:0').as_strided((6400, 1600), (1600, 1)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((10240000,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600, 6400), (6400, 1)),
    torch.randn((3276800,), dtype=torch.bfloat16, device='cuda:0').as_strided((16, 128, 1600), (204800, 1600, 1)),
    torch.randn((3276800,), dtype=torch.bfloat16, device='cuda:0').as_strided((16, 25, 128, 64), (204800, 8192, 64, 1)),
]
fd.execute(inputs)
```

Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 145, in execute
    result = self._execute(
RuntimeError: !detect_exception_in_thread_pool.load() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/kernel_cache.cpp":1336, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Detected exception while compiling fusion segments in parallel. Error messages from all threads are printed below.

Error from segmentation group 5: Merging IterDomains requires that their iteration types match. Outer: iS284{( ceilDiv(1600, 32) )}, Inner: rS257{i6}
Exception raised from merge at /opt/pytorch/nvfuser/csrc/ir/nodes.cpp:2558 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x727dc0dfca67 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::IterDomain::merge(nvfuser::IterDomain*, nvfuser::IterDomain*, bool) + 0x38a (0x727dc12d2a0a in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x5fcb6d (0x727dc12d2b6d in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: nvfuser::TensorView::merge(long, long) + 0x106 (0x727dc14a27b6 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::scheduleTranspose(nvfuser::Fusion*, nvfuser::TransposeParams) + 0xe97 (0x727dc146d037 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x799751 (0x727dc146f751 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x64dbc6 (0x727dc1323bc6 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x64de4c (0x727dc1323e4c in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: c10::ThreadPool::main_loop(unsigned long) + 0x2bd (0x727ed0d87c7d in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #9: <unknown function> + 0xdc253 (0x727ede9f2253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #10: <unknown function> + 0x94ac3 (0x727edebd8ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #11: <unknown function> + 0x126850 (0x727edec6a850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Use NVFUSER_DISABLE=parallel_compile to simplify error message.
Exception raised from compileFusionParallel at /opt/pytorch/nvfuser/csrc/kernel_cache.cpp:1336 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x727dc0dfca67 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x727dc112d3b3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0x1600 (0x727dc132e020 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional) + 0xad3 (0x727dc1337ce3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional, bool, bool, bool) const + 0x3ec (0x727dc15283fc in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x19e88e (0x727dc0e7488e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x2153ff (0x727dc0eeb3ff in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x2a9be0 (0x727dc0f7fbe0 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x15a10e (0x6284c628710e in /usr/bin/python3)
frame #9: _PyObject_MakeTpCall + 0x25b (0x6284c627da7b in /usr/bin/python3)
frame #10: <unknown function> + 0x168acb (0x6284c6295acb in /usr/bin/python3)
frame #11: _PyEval_EvalFrameDefault + 0x198c (0x6284c627153c in /usr/bin/python3)
frame #12: <unknown function> + 0x1687f1 (0x6284c62957f1 in /usr/bin/python3)
frame #13: PyObject_Call + 0x122 (0x6284c6296492 in /usr/bin/python3)
frame #14: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #15: _PyObject_FastCallDictTstate + 0xc4 (0x6284c627cc14 in /usr/bin/python3)
frame #16: _PyObject_Call_Prepend + 0xc1 (0x6284c62928d1 in /usr/bin/python3)
frame #17: <unknown function> + 0x280700 (0x6284c63ad700 in /usr/bin/python3)
frame #18: _PyObject_MakeTpCall + 0x25b (0x6284c627da7b in /usr/bin/python3)
frame #19: _PyEval_EvalFrameDefault + 0x64e6 (0x6284c6276096 in /usr/bin/python3)
frame #20: _PyFunction_Vectorcall + 0x7c (0x6284c62879fc in /usr/bin/python3)
frame #21: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #22: _PyFunction_Vectorcall + 0x7c (0x6284c62879fc in /usr/bin/python3)
frame #23: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #24: _PyFunction_Vectorcall + 0x7c (0x6284c62879fc in /usr/bin/python3)
frame #25: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #26: _PyFunction_Vectorcall + 0x7c (0x6284c62879fc in /usr/bin/python3)
frame #27: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #28: _PyFunction_Vectorcall + 0x7c (0x6284c62879fc in /usr/bin/python3)
frame #29: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #30: <unknown function> + 0x16893e (0x6284c629593e in /usr/bin/python3)
frame #31: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #32: <unknown function> + 0x16893e (0x6284c629593e in /usr/bin/python3)
frame #33: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #34: _PyObject_FastCallDictTstate + 0xc4 (0x6284c627cc14 in /usr/bin/python3)
frame #35: _PyObject_Call_Prepend + 0x5c (0x6284c629286c in /usr/bin/python3)
frame #36: <unknown function> + 0x280700 (0x6284c63ad700 in /usr/bin/python3)
frame #37: PyObject_Call + 0xbb (0x6284c629642b in /usr/bin/python3)
frame #38: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #39: _PyFunction_Vectorcall + 0x7c (0x6284c62879fc in /usr/bin/python3)
frame #40: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #41: _PyFunction_Vectorcall + 0x7c (0x6284c62879fc in /usr/bin/python3)
frame #42: _PyEval_EvalFrameDefault + 0x6bd (0x6284c627026d in /usr/bin/python3)
frame #43: _PyFunction_Vectorcall + 0x7c (0x6284c62879fc in /usr/bin/python3)
frame #44: _PyEval_EvalFrameDefault + 0x8ac (0x6284c627045c in /usr/bin/python3)
frame #45: <unknown function> + 0x16893e (0x6284c629593e in /usr/bin/python3)
frame #46: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #47: _PyObject_FastCallDictTstate + 0xc4 (0x6284c627cc14 in /usr/bin/python3)
frame #48: _PyObject_Call_Prepend + 0x5c (0x6284c629286c in /usr/bin/python3)
frame #49: <unknown function> + 0x280700 (0x6284c63ad700 in /usr/bin/python3)
frame #50: PyObject_Call + 0xbb (0x6284c629642b in /usr/bin/python3)
frame #51: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #52: _PyFunction_Vectorcall + 0x7c (0x6284c62879fc in /usr/bin/python3)
frame #53: PyObject_Call + 0x122 (0x6284c6296492 in /usr/bin/python3)
frame #54: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #55: _PyFunction_Vectorcall + 0x7c (0x6284c62879fc in /usr/bin/python3)
frame #56: _PyEval_EvalFrameDefault + 0x2a27 (0x6284c62725d7 in /usr/bin/python3)
frame #57: _PyFunction_Vectorcall + 0x7c (0x6284c62879fc in /usr/bin/python3)
frame #58: _PyEval_EvalFrameDefault + 0x614a (0x6284c6275cfa in /usr/bin/python3)
frame #59: <unknown function> + 0x1687f1 (0x6284c62957f1 in /usr/bin/python3)
frame #60: _PyEval_EvalFrameDefault + 0x614a (0x6284c6275cfa in /usr/bin/python3)
frame #61: _PyFunction_Vectorcall + 0x7c (0x6284c62879fc in /usr/bin/python3)
frame #62: _PyObject_FastCallDictTstate + 0x16d (0x6284c627ccbd in /usr/bin/python3)
frame #63: _PyObject_Call_Prepend + 0x5c (0x6284c629286c in /usr/bin/python3)
======================================================================================================================================================================================================================================= short test summary info =======================================================================================================================================================================================================================================
FAILED thunder/benchmarks/targets.py::test_nanogpt_block_fwd[thunder] - RuntimeError: !detect_exception_in_thread_pool.load() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/kernel_cache.cpp":1336, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Detected exception while compiling fusion segments in parallel. Error messages from all threads are printed below.
============================================================================================================================================================================================================================ 1 failed, 339 deselected, 6 warnings in 7.62s ============================================================================================================================================================================================================================

@wujingyue
Collaborator Author

wujingyue commented Jun 6, 2024

Update: never mind. I should have run test_nanogpt_block[inference-thunder] instead. forward-thunder is the forward pass in training, which is different from the previous benchmark.

Interestingly, I'm unable to reproduce the problem after Lightning-AI/lightning-thunder@d1b016a, authored by @IvanYashchuk.

$ git log
commit 81425474dad41afa1f3100efea63faa8fd062a68 (HEAD -> wjy/bug2317)
Author: Jingyue Wu <[email protected]>
Date:   Thu May 9 20:39:36 2024 +0000

    Unconditionally enable linear and matmul and turn off nv_enable_bookend.

commit d1b016a58a48e5c6282622de488be8c9135dd821
Author: Ivan Yashchuk <[email protected]>
Date:   Mon Jun 3 14:56:03 2024 +0300

    Update benchmarks/targets.py: inference/forward/backward parametrization (#498)

commit 3bd0100218055ad3713466d8fc03647928f7a289
Author: Masaki Kozuki <[email protected]>
Date:   Mon Jun 3 18:19:57 2024 +0900

    Remove trace dump for debug from test_tensor_parallel.py (#509)
$ pytest thunder/benchmarks/targets.py -k test_nanogpt_block[forward-thunder] -s
========================================================================================================================================================================================================================================= test session starts =========================================================================================================================================================================================================================================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=torch.utils.benchmark.utils.timer.timer disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=True warmup_iterations=100000)
rootdir: /opt/pytorch/lightning-thunder
configfile: pyproject.toml
plugins: timestamper-0.0.10, xdist-3.5.0, random-order-1.1.1, cov-4.1.0, benchmark-4.0.0, hypothesis-6.100.0, timeout-2.2.0, anyio-4.3.0, shard-0.1.2
timeout: 900.0s
timeout method: signal
timeout func_only: False
collected 644 items / 643 deselected / 1 selected
Running 1 items in this shard

thunder/benchmarks/targets.py .


------------------------------------------------------ benchmark: 1 tests ------------------------------------------------------
Name (time in ms)                          Min       Max    Mean  StdDev  Median     IQR  Outliers       OPS  Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------
test_nanogpt_block[forward-thunder]     1.2142  123.1338  1.5354  4.2465  1.3960  0.0048      1;77  651.3001     822           1
--------------------------------------------------------------------------------------------------------------------------------

Legend:
  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
  OPS: Operations Per Second, computed as 1 / Mean
========================================================================================================================================================================================================================== 1 passed, 643 deselected, 2376 warnings in 11.46s ==========================================================================================================================================================================================================================

@wujingyue
Collaborator Author

FYI @Priya2698, I synced https://github.com/Lightning-AI/lightning-thunder/tree/wjy/bug2317. You can still reproduce the same problem using pytest thunder/benchmarks/targets.py -k test_nanogpt_block[inference-thunder] -s.

@jacobhinkle
Collaborator

This is happening when we segment at a reduction output. In the consumer segments, the edge is converted to an input that has a Reduction domain. Instead, I think we should filter out Reduction domains in convertInputRFactorsToRoots (and adjust stride order accordingly).
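To make the suggestion concrete, here is a toy sketch of dropping reduction axes from a segment input and renumbering the stride (allocation) order to match. It is not nvFuser's `convertInputRFactorsToRoots`; the function `strip_reduction_axes`, the string iteration types, and the "outermost to innermost" permutation convention are all assumptions made for illustration.

```python
def strip_reduction_axes(iter_types, allocation_order):
    """Drop reduction axes and renumber the allocation order.

    iter_types: per-axis type, "iteration" or "reduction", in logical order.
    allocation_order: permutation of axis indices, outermost to innermost
        in memory (an assumed convention for this sketch).
    Returns (kept iter types, allocation order over the kept axes).
    """
    kept = [axis for axis, kind in enumerate(iter_types) if kind != "reduction"]
    new_index = {old: new for new, old in enumerate(kept)}
    new_types = [iter_types[axis] for axis in kept]
    new_order = [new_index[axis] for axis in allocation_order if axis in new_index]
    return new_types, new_order


# Shaped like T60_g in this issue: three iteration axes plus a trailing reduction.
t60_types, t60_order = strip_reduction_axes(
    ["iteration", "iteration", "iteration", "reduction"], [0, 1, 2, 3]
)

# A permuted layout, to show the renumbering when the dropped axis is in the middle.
mid_types, mid_order = strip_reduction_axes(
    ["iteration", "reduction", "iteration"], [2, 1, 0]
)
print(t60_types, t60_order)
print(mid_types, mid_order)
```

With the reduction axis gone from the consumer segment's input, schedulers downstream never see a mixed iteration/reduction pair to merge.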

Priya2698 added a commit that referenced this issue Jun 25, 2024
Issue #2317.

The issue arises in the following lines for reference 1: `[I0, I1, I2, r3]`:

After tiling:
```
  reference1->split(inner_most_pos1_in_ref1, params.tile_size1);
  reference1->reorder({{inner_most_pos1_in_ref1 + 1, -1}});
  reference1->split(inner_most_pos2_in_ref1, params.tile_size2);
  reference1->reorder({{inner_most_pos2_in_ref1 + 1, -1}});
```
Reference 1 is: [I0, I1/tile1, I2/tile2, r3, tile1, tile2]
```
  // Merge remaining dimensions
  int64_t lhs_i = -1;
  for (int64_t i = reference1->nDims() - 2; i > 0; i--) {
    auto axis_i = i - 1;
    if (lhs_i == -1) {
      lhs_i = axis_i;
    } else {
      reference1->merge(axis_i, lhs_i);
      lhs_i = axis_i;
    }
  }
```
This tries to merge a reduction IterDomain with an iteration-type IterDomain, which `IterDomain::merge` rejects.

This PR ignores the reduction axis when merging all non-tile dimensions.
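The failure and the fix can be replayed on a toy model of the post-tiling domain `[I0, I1/tile1, I2/tile2, r3, tile1, tile2]`. This is a sketch under stated assumptions: `Axis`, `merge`, and `merge_non_tile_dims` are hypothetical Python stand-ins, only two iteration types are modeled, and skipping reduction axes with `continue` is one plausible reading of the fix rather than the exact code in the PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Axis:
    name: str
    is_reduction: bool = False


def merge(domain, outer_pos, inner_pos):
    """Toy stand-in for TensorView::merge: refuse to fuse mixed iteration types."""
    outer, inner = domain[outer_pos], domain[inner_pos]
    if outer.is_reduction != inner.is_reduction:
        raise RuntimeError(
            "Merging IterDomains requires that their iteration types match. "
            f"Outer: {outer.name}, Inner: {inner.name}"
        )
    merged = Axis(f"{outer.name}*{inner.name}", outer.is_reduction)
    lo, hi = sorted((outer_pos, inner_pos))
    # The merged axis takes the place of the lower of the two positions.
    return domain[:lo] + [merged] + domain[lo + 1 : hi] + domain[hi + 1 :]


def merge_non_tile_dims(domain, skip_reductions):
    """Mirror of the scheduler loop: merge everything before the two tile axes."""
    lhs_i = -1
    for i in range(len(domain) - 2, 0, -1):
        axis_i = i - 1
        if skip_reductions and domain[axis_i].is_reduction:
            continue  # assumed fix: leave reduction axes out of the merge
        if lhs_i == -1:
            lhs_i = axis_i
        else:
            domain = merge(domain, axis_i, lhs_i)
            lhs_i = axis_i
    return domain


domain = [
    Axis("I0"), Axis("I1/tile1"), Axis("I2/tile2"),
    Axis("r3", is_reduction=True), Axis("tile1"), Axis("tile2"),
]

try:
    merge_non_tile_dims(domain, skip_reductions=False)
    error_message = None
except RuntimeError as err:
    error_message = str(err)

fixed = merge_non_tile_dims(domain, skip_reductions=True)
print(error_message)
print([axis.name for axis in fixed])
```

In the unpatched loop the first merge pairs `I2/tile2` (outer) with `r3` (inner), the same outer-iteration, inner-reduction combination reported in the error message above.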
protonu pushed a commit that referenced this issue Jun 26, 2024
jacobhinkle added a commit that referenced this issue Jul 1, 2024
Currently, when we segment at the output of a reduction, the consumer
segment will have an input tensor that has a `Reduction` axis in it.
This can be problematic; see #2481 and #2317. This PR strips reduction
axes from the root and allocation domain in these cases.

---------

Co-authored-by: Jingyue Wu <[email protected]>