Mismatch in IterDomain Iteration Types in TorchBench functorch_dp_cifar10 #2008

Closed
kevinstephano opened this issue Sep 29, 2022 · 3 comments · Fixed by #2037
kevinstephano (Collaborator) commented Sep 29, 2022

🐛 Describe the bug

The bug occurs during scheduling:

Prescheduled Fusion IR:

Inputs:
  T0_g[ iS0{i0}, iS1{i1}, bS2{1}, bS3{1} ], float
  T1_g[ iS4{i6}, iS5{i7}, bS6{1}, bS7{1} ], float
Outputs:
  T4_g[ iS16{i0}, iS17{i1}, bS18{1}, bS19{1} ], float
  T8_g[ iS32{i0}, iS33{i1}, bS34{1}, bS35{1} ], float

%kernel_math {
T2_l[ iS8{i0}, iS9{i1}, bS10{1}, bS11{1} ]
   = T0_g[ iS0{i0}, iS1{i1}, bS2{1}, bS3{1} ]
   + T1_g[ iS4{i6}, iS5{i7}, bS6{1}, bS7{1} ];
T3_l[ iS12{i0}, iS13{i1}, bS14{1}, bS15{1} ]
   = T2_l[ iS8{i0}, iS9{i1}, bS10{1}, bS11{1} ]
   <= double(0.0000000000000000);
T4_g[ iS16{i0}, iS17{i1}, bS18{1}, bS19{1} ]
   = where(T3_l[ iS12{i0}, iS13{i1}, bS14{1}, bS15{1} ]
  , double(0.0000000000000000)
  , T2_l[ iS8{i0}, iS9{i1}, bS10{1}, bS11{1} ]);
T5_l[ iS20{i0}, iS21{i1}, bS22{1}, bS23{1} ]
   = T4_g[ iS16{i0}, iS17{i1}, bS18{1}, bS19{1} ];
T6_l[ iS24{i0}, iS25{i1}, rS26{1}, rS27{1} ]
   = reduction( T5_l[ iS20{i0}, iS21{i1}, bS22{1}, bS23{1} ], op = add, initial value = double(0.0000000000000000), allreduce = false )
T7_l[ iS28{i0}, iS29{i1}, bS30{1}, bS31{1} ]
   = broadcast( T6_l[ iS24{i0}, iS25{i1}, rS26{1}, rS27{1} ] )
T8_g[ iS32{i0}, iS33{i1}, bS34{1}, bS35{1} ]
   = T7_l[ iS28{i0}, iS29{i1}, bS30{1}, bS31{1} ]
   / double(1.0000000000000000);
}

Error:

RuntimeError: Merging IterDomains requires that their iteration types match.

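For context, this error comes from the check that rejects merging two IterDomains with different iteration types (e.g. Iteration vs. Reduction). Below is a minimal, hypothetical sketch that should trip the same check directly, written against the C++ test API used in the repro later in this thread; the test name and the manual merge of axes 1 and 2 are illustrative assumptions, not the scheduler's actual code path:

TEST_F(NVFuserTest, FusionMergeMismatchedIterTypes_CUDA) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(fusion_ptr.get());

  // Two iteration dims followed by two size-1 dims, as in the fusion above.
  auto tv0 = makeConcreteTensor({64, 512, 1, 1});
  fusion.addInput(tv0);

  // Reducing the size-1 dims leaves tv1 with [iS, iS, rS{1}, rS{1}].
  auto tv1 = sum(tv0, {2, 3});
  fusion.addOutput(tv1);

  // Merging an Iteration domain (axis 1) with a Reduction domain (axis 2)
  // is expected to fail with the same
  // "Merging IterDomains requires that their iteration types match." error.
  ASSERT_ANY_THROW(tv1->merge(1, 2));
}
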
Reproducing requires bumping the devel fork from upstream to pick up the Python frontend changes.

Repro:

import torch
from torch._C._nvfuser import FusionDefinition, Fusion, DataType

def nvfuser_fusion_id5(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, 1, 1], contiguous=[True, True, True, True], dtype=DataType.Float)
    T1 = fd.define_tensor(symbolic_sizes=[-1, -1, 1, 1], contiguous=[True, True, True, True], dtype=DataType.Float)
    T2 = fd.ops.add(T0, T1)
    S3 = fd.define_constant(0.00000)
    T4 = fd.ops.le(T2, S3)
    S5 = fd.define_constant(0.00000)
    T6 = fd.ops.where(T4, S5, T2)
    T7 = fd.ops.cast(T6, dtype=DataType.Float)
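    # T8/T9 reduce over the two size-1 (broadcast) dims and broadcast back;
    # this reduction-broadcast pattern is what the discussion below points to
    # as tripping the pointwise scheduler.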
    T8 = fd.ops.sum(T7, axes=[3, 2], keepdim=False, dtype=DataType.Null)
    T9 = fd.ops.broadcast_in_dim(T8, output_shape=[64, 512, 1, 1], broadcast_dims=[0, 1])
    S10 = fd.define_constant(1.00000)
    T11 = fd.ops.div(T9, S10)
    fd.add_output(T6)
    fd.add_output(T11)

fs = Fusion()
with FusionDefinition(fs) as fd:
    nvfuser_fusion_id5(fd)

input1 = torch.randn(64, 512, 1, 1, device='cuda', dtype=torch.float32)
input2 = torch.randn(64, 512, 1, 1, device='cuda', dtype=torch.float32)

out = fs.execute([input1, input2])

Versions

Upstream TOT?

kevinstephano changed the title from "[Place holder for real title] Bug in TorchBench functorch_dp_cifar10" to "Mismatch in IterDomain Iteration Types in TorchBench functorch_dp_cifar10" on Oct 3, 2022
naoyam (Collaborator) commented Oct 6, 2022

Just noting what I'm seeing so far. Here's the fusion math:

%kernel_math {
T2_l[ iS8{i0}, iS9{i1}, bS10{1}, bS11{1} ]
   = T0_g[ iS0{i0}, iS1{i1}, bS2{1}, bS3{1} ]
   + T1_g[ iS4{i6}, iS5{i7}, bS6{1}, bS7{1} ];
T3_l[ iS12{i0}, iS13{i1}, bS14{1}, bS15{1} ]
   = T2_l[ iS8{i0}, iS9{i1}, bS10{1}, bS11{1} ]
   <= double(0);
T4_g[ iS16{i0}, iS17{i1}, bS18{1}, bS19{1} ]
   = where(T3_l[ iS12{i0}, iS13{i1}, bS14{1}, bS15{1} ]
  , double(0)
  , T2_l[ iS8{i0}, iS9{i1}, bS10{1}, bS11{1} ]);
T5_l[ iS20{i0}, iS21{i1}, bS22{1}, bS23{1} ]
   = T4_g[ iS16{i0}, iS17{i1}, bS18{1}, bS19{1} ];
T6_l[ iS24{i0}, iS25{i1}, rS26{1}, rS27{1} ]
   = reduction( T5_l[ iS20{i0}, iS21{i1}, bS22{1}, bS23{1} ], op = add, initial value = double(0), allreduce = false )
T7_l[ iS28{i0}, iS29{i1}, bS30{1}, bS31{1} ]
   = broadcast( T6_l[ iS24{i0}, iS25{i1}, rS26{1}, rS27{1} ] )
T8_g[ iS32{i0}, iS33{i1}, bS34{1}, bS35{1} ]
   = T7_l[ iS28{i0}, iS29{i1}, bS30{1}, bS31{1} ]
   / double(1);
}

And we are using the pointwise scheduler, which seems to be failing due to the reduction-broadcast pattern.

naoyam (Collaborator) commented Oct 6, 2022

C++ repro:

TEST_F(NVFuserTest, FusionReplayTrivialReductionAndBroadcast_CUDA) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(fusion_ptr.get());

  std::vector<int64_t> shape({10, 5, 1, 1});
  //std::vector<int64_t> shape({10, 5, 1});

  auto tv0 = makeConcreteTensor(shape);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = sum(tv1, {2, 3});
  //auto tv2 = sum(tv1, {2});
  auto tv3 = broadcast(tv2, {false, false, true});
  fusion.addOutput(tv3);

  fusion.printMath();
  fusion.printKernel();

  tv0->merge(-2, -1)->merge(-2, -1)->merge(-2, -1)->split(0, 4);
  //tv0->merge(-2, -1)->merge(-2, -1)->split(0, 4);

  MaxRootDomainInfoSpanningTree tree(tv0);
  TransformPropagator tp(tv0);
  tree.traverse(&tp);

  fusion.printMath();
  fusion.printKernel();
}

Looks like the problem is in replaying the propagation from a producer to a consumer when there is more than one trivial reduction. In the above repro, the replay works fine when there is only one trivial domain (as shown in the commented-out lines). I'll look into it more closely.
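
For comparison, here is the single-trivial-reduction variant assembled from the commented-out lines above (the test name is just a placeholder); per the comment, transform propagation succeeds in this form:

TEST_F(NVFuserTest, FusionReplaySingleTrivialReductionAndBroadcast_CUDA) {
  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr;
  FusionGuard fg(fusion_ptr.get());

  // Only one trailing size-1 dimension, so only one trivial reduction domain.
  std::vector<int64_t> shape({10, 5, 1});

  auto tv0 = makeConcreteTensor(shape);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = sum(tv1, {2});
  auto tv3 = broadcast(tv2, {false, false, true});
  fusion.addOutput(tv3);

  // One fewer merge than the failing case, since tv0 is now rank 3.
  tv0->merge(-2, -1)->merge(-2, -1)->split(0, 4);

  MaxRootDomainInfoSpanningTree tree(tv0);
  TransformPropagator tp(tv0);
  tree.traverse(&tp);
}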

csarofeen (Owner) commented:

Thanks.

naoyam added a commit that referenced this issue Oct 6, 2022
* Allow non-root trivial reductions

Fixes #2008

Co-authored-by: Christian Sarofeen <[email protected]>
naoyam added a commit that referenced this issue Oct 6, 2022
* Fix vectorize size calculation (#2035)

* Allow non-root trivial reductions (#2037)

* Allow non-root trivial reductions

Fixes #2008

Co-authored-by: Christian Sarofeen <[email protected]>

* Test file cleanup (#2040)

* Move test_gpu.cpp to test_gpu1.cpp

* Split test_gpu1.cpp to test_gpu1.cpp, test_gpu2.cpp and test_gpu3.cpp.

Each file should be up to 10K LoC. New tests should be added to
test_gpu3.cpp until it gets 10K LoC.

Co-authored-by: Gao, Xiang <[email protected]>
Co-authored-by: Christian Sarofeen <[email protected]>
naoyam added a commit that referenced this issue Oct 6, 2022
* Allow non-root trivial reductions (#2037)

* Allow non-root trivial reductions

Fixes #2008

Co-authored-by: Christian Sarofeen <[email protected]>

* Test file cleanup (#2040)

* Move test_gpu.cpp to test_gpu1.cpp

* Split test_gpu1.cpp to test_gpu1.cpp, test_gpu2.cpp and test_gpu3.cpp.

Each file should be up to 10K LoC. New tests should be added to
test_gpu3.cpp until it gets 10K LoC.

* format

* fix merge

* format

Co-authored-by: Christian Sarofeen <[email protected]>