From d30dccd21f2c3bc1ed6cd054c131436a1af548e1 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Thu, 6 Oct 2022 17:41:48 +0000 Subject: [PATCH] [mlir][sparse] Favors synthetic tensor over other undefined tensors Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D135385 --- .../lib/Dialect/SparseTensor/Utils/Merger.cpp | 25 ++-- .../Dialect/SparseTensor/MergerTest.cpp | 110 ++++++++++++++---- 2 files changed, 104 insertions(+), 31 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 60fab7b5d0070e..187a6c0b188b20 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -265,21 +265,26 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) { BitVector simple = latPoints[p0].bits; bool reset = isSingleton && hasAnySparse(simple); - unsigned offset = 0; + unsigned be = simple.size(); + unsigned offset = 0; // relative to the end if (!reset) // Starts resetting from a dense dimension, so that the first bit (if kept) // is not undefined dimension type. - for (unsigned b = 0, be = simple.size(); b < be; b++) - if (simple[b] && isDimLevelType(b, DimLvlType::kDense)) - offset = b; + for (unsigned b = 0; b < be; b++) { + if (simple[b] && isDimLevelType(b, DimLvlType::kDense)) { + offset = be - b - 1; // relative to the end + break; + } + } - // Now apply the two basic rules. - for (unsigned b = 0, be = simple.size(); b < be; b++) { - unsigned i = (offset + b) % be; - if (simple[i] && (!isDimLevelType(i, DimLvlType::kCompressed) && - !isDimLevelType(i, DimLvlType::kSingleton))) { + // Now apply the two basic rules. We also iterate the bits reversely to always + // keep the rightmost bit (which could possibly be a synthetic tensor). + for (unsigned b = be - 1 - offset, i = 0; i < be; + b = b == 0 ? be - 1 : b - 1, i++) { + if (simple[b] && (!isDimLevelType(b, DimLvlType::kCompressed) && + !isDimLevelType(b, DimLvlType::kSingleton))) { if (reset) - simple.reset(i); + simple.reset(b); reset = true; } } diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp index 9851f456d4b55b..4e95b8f6b2ebb9 100644 --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -380,15 +380,16 @@ class MergerTest3T1LD : public MergerTestBase { /// /// Tests with both undef and dense input. /// -class MergerTest3T1LU : public MergerTestBase { + +class MergerTest4T1LU : public MergerTestBase { protected: // Our three tensors (two inputs, one output). - const unsigned t0 = 0, t1 = 1, t2 = 2; + const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3; // Our single loop. const unsigned l0 = 0; - MergerTest3T1LU() : MergerTestBase(3, 1) { + MergerTest4T1LU() : MergerTestBase(4, 1) { // Tensor 0: undef input vector. merger.addExp(Kind::kTensor, t0, -1u); merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef)); @@ -397,43 +398,110 @@ class MergerTest3T1LU : public MergerTestBase { merger.addExp(Kind::kTensor, t1, -1u); merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kDense)); - // Tensor 2: dense output vector. + // Tensor 2: undef input vector. merger.addExp(Kind::kTensor, t2, -1u); - merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kDense)); + merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kUndef)); + + // Tensor 3: dense output vector. + merger.addExp(Kind::kTensor, t3, -1u); + merger.setDimLevelFormat(t3, l0, DimLevelFormat(DimLvlType::kDense)); + } +}; + +/// +/// Tests with operation on sparse output. +/// + +class MergerTest3T1L_SO : public MergerTestBase { +protected: + // Our three tensors (two inputs, one output, one synthetic). + const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3; + + // Our single loop. + const unsigned l0 = 0; + + MergerTest3T1L_SO() : MergerTestBase(3, 1) { + merger.setHasSparseOut(true); + + // Tensor 0: undef input vector. + merger.addExp(Kind::kTensor, t0, -1u); + merger.setDimLevelFormat(t0, l0, DimLevelFormat(DimLvlType::kUndef)); + + // Tensor 1: undef input vector. + merger.addExp(Kind::kTensor, t1, -1u); + merger.setDimLevelFormat(t1, l0, DimLevelFormat(DimLvlType::kUndef)); + + // Tensor 2: sparse output vector. + merger.addExp(Kind::kTensor, t2, -1u); + merger.setDimLevelFormat(t2, l0, DimLevelFormat(DimLvlType::kCompressed)); } }; + } // namespace -/// Vector multiplication (conjunction) of 2 vectors, i.e.; -/// a(i) = b(i) * c(i) +/// Vector multiplication (conjunction) of 3 vectors, i.e.; +/// a(i) = b(i) * c(i) * d(i) /// which should form the single lattice point /// { -/// lat( i_00_U i_01_D / (tensor_0 * tensor_1) ) +/// lat( i_00_U i_01_D i_02_U / (tensor_0 * tensor_1 * tensor2) ) /// } /// after optimization, the dense dimesion should be kept, despite it appears -/// after the undef dimension +/// in the middle /// { -/// lat( i_01_D / (tensor_0 * tensor_1) ) +/// lat( i_01_D / (tensor_0 * tensor_1 * tensor2) ) /// } -#define IMPL_MERGER_TEST_CONJ(OP) \ - TEST_F(MergerTest3T1LU, vector_##OP) { \ - auto e = OP##Expr(t0, t1); \ +#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \ + TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \ + auto em = CONJ1##Expr(t0, t1); \ + auto e = CONJ2##Expr(em, t2); \ auto p0 = tensorPattern(t0); \ auto p1 = tensorPattern(t1); \ + auto p2 = tensorPattern(t2); \ auto s = merger.buildLattices(e, l0); \ - \ expectNumLatPoints(s, 1); \ - expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ - loopsToBits({{l0, t0}, {l0, t1}})); \ - \ + expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ + loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ s = merger.optimizeSet(s); \ expectNumLatPoints(s, 1); \ - expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t1}}), \ - true); \ + expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ + loopsToBits({{l0, t1}}), true); \ } -FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ) -#undef IMPL_MERGER_TEST_CONJ +FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF) + +#undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF + +/// Vector multiplication (conjunction) of 2 vectors, i.e.; +/// o(i) = b(i) * c(i) * o(i) +/// which should form the single lattice point (note how a synthetic tensor +/// i_03_U is created for the sparse output) +/// { +/// lat( i_00_U i_01_U i_03_U / (tensor_0 * tensor_1 * output_tensor_2) ) +/// } +/// after optimization, the synthetic tensor should be preserved. +/// { +/// lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) ) +/// } +#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \ + TEST_F(MergerTest3T1L_SO, vector_##CONJ1##_##CONJ2) { \ + auto em = CONJ1##Expr(t0, t1); \ + auto e = CONJ2##Expr(em, t2); \ + auto p0 = tensorPattern(t0); \ + auto p1 = tensorPattern(t1); \ + auto p2 = tensorPattern(t2); \ + auto s = merger.buildLattices(e, l0); \ + expectNumLatPoints(s, 1); \ + expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ + loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}})); \ + s = merger.optimizeSet(s); \ + expectNumLatPoints(s, 1); \ + expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ + loopsToBits({{l0, t3}}), true); \ + } + +FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT) + +#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT /// Vector addition (disjunction) of 2 vectors. i.e.; /// a(i) = b(i) + c(i)