Skip to content

Commit

Permalink
[CINN]disable float auto simplify (#64791) (#65075)
Browse files Browse the repository at this point in the history
* disable float auto simplify

* fix unit test bug

* fix unit tset bug
  • Loading branch information
phlrain authored Jun 12, 2024
1 parent a623686 commit c00f1de
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 48 deletions.
44 changes: 22 additions & 22 deletions paddle/cinn/backends/ir_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1398,7 +1398,7 @@ void test_cache_read1(void* _args, int32_t num_args)
};
for (int32_t i = 0; i < 32; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
B[((32 * i) + j)] = (2.00000000f * A_local_temp_buffer[((64 * i) + j)]);
B[((32 * i) + j)] = (A_local_temp_buffer[((64 * i) + j)] * 2.00000000f);
};
};
for (int32_t cache_ax0_0 = 0; cache_ax0_0 < 16; cache_ax0_0 += 1) {
Expand All @@ -1408,7 +1408,7 @@ void test_cache_read1(void* _args, int32_t num_args)
};
for (int32_t i = 0; i < 16; i += 1) {
for (int32_t j = 0; j < 16; j += 1) {
C[((16 * i) + j)] = (1.00000000f + B_local_temp_buffer[((32 * i) + j)]);
C[((16 * i) + j)] = (B_local_temp_buffer[((32 * i) + j)] + 1.00000000f);
};
};
cinn_buffer_free((void*)(0), _B);
Expand Down Expand Up @@ -1480,7 +1480,7 @@ void test_cache_read2(void* _args, int32_t num_args)
for (int32_t i = 0; i < 64; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
A_local_temp_buffer[((32 * i) + j)] = A[((32 * i) + j)];
B[((32 * i) + j)] = (2.00000000f * A_local_temp_buffer[((32 * i) + j)]);
B[((32 * i) + j)] = (A_local_temp_buffer[((32 * i) + j)] * 2.00000000f);
};
};
cinn_buffer_free((void*)(0), _B);
Expand Down Expand Up @@ -1553,7 +1553,7 @@ void test_cache_write1(void* _args, int32_t num_args)
float* C = ((float*)(_C->memory));
for (int32_t i = 0; i < 64; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
B_local_temp_buffer[((32 * i) + j)] = (2.00000000f * A[((32 * i) + j)]);
B_local_temp_buffer[((32 * i) + j)] = (A[((32 * i) + j)] * 2.00000000f);
};
};
for (int32_t cache_ax0 = 0; cache_ax0 < 64; cache_ax0 += 1) {
Expand All @@ -1563,7 +1563,7 @@ void test_cache_write1(void* _args, int32_t num_args)
};
for (int32_t i = 0; i < 64; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
C_local_temp_buffer[((32 * i) + j)] = (1.00000000f + B[((32 * i) + j)]);
C_local_temp_buffer[((32 * i) + j)] = (B[((32 * i) + j)] + 1.00000000f);
};
};
for (int32_t cache_ax0_0 = 0; cache_ax0_0 < 64; cache_ax0_0 += 1) {
Expand Down Expand Up @@ -1637,7 +1637,7 @@ void test_cache_write2(void* _args, int32_t num_args)
float* B = ((float*)(_B->memory));
for (int32_t cache_ax0 = 0; cache_ax0 < 64; cache_ax0 += 1) {
for (int32_t cache_ax1 = 0; cache_ax1 < 32; cache_ax1 += 1) {
B_local_temp_buffer[((32 * cache_ax0) + cache_ax1)] = (2.00000000f * A[((32 * cache_ax0) + cache_ax1)]);
B_local_temp_buffer[((32 * cache_ax0) + cache_ax1)] = (A[((32 * cache_ax0) + cache_ax1)] * 2.00000000f);
B[((32 * cache_ax0) + cache_ax1)] = B_local_temp_buffer[((32 * cache_ax0) + cache_ax1)];
};
};
Expand Down Expand Up @@ -1713,7 +1713,7 @@ void test_cache_read3(const float* __restrict__ A, float* __restrict__ C)
};
for (int32_t i = 0; i < 32; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
B[((32 * i) + j)] = (2.00000000f * A_local_temp_buffer[((64 * i) + j)]);
B[((32 * i) + j)] = (A_local_temp_buffer[((64 * i) + j)] * 2.00000000f);
};
__syncthreads();
};
Expand All @@ -1725,7 +1725,7 @@ void test_cache_read3(const float* __restrict__ A, float* __restrict__ C)
for (int32_t i = 0; i < 16; i += 1) {
__syncthreads();
for (int32_t j = 0; j < 16; j += 1) {
C[((16 * i) + j)] = (1.00000000f + B_local_temp_buffer[((32 * i) + j)]);
C[((16 * i) + j)] = (B_local_temp_buffer[((32 * i) + j)] + 1.00000000f);
};
};
}
Expand Down Expand Up @@ -1794,7 +1794,7 @@ void test_cache_write3(const float* __restrict__ A, float* __restrict__ C)
float* B = _B_temp_buffer;
for (int32_t i = 0; i < 64; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
B_local_temp_buffer[((32 * i) + j)] = (2.00000000f * A[((32 * i) + j)]);
B_local_temp_buffer[((32 * i) + j)] = (A[((32 * i) + j)] * 2.00000000f);
};
};
for (int32_t cache_ax0 = 0; cache_ax0 < 64; cache_ax0 += 1) {
Expand All @@ -1805,7 +1805,7 @@ void test_cache_write3(const float* __restrict__ A, float* __restrict__ C)
__syncthreads();
for (int32_t i = 0; i < 64; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
C_local_temp_buffer[((32 * i) + j)] = (1.00000000f + B[((32 * i) + j)]);
C_local_temp_buffer[((32 * i) + j)] = (B[((32 * i) + j)] + 1.00000000f);
};
};
__syncthreads();
Expand Down Expand Up @@ -1878,7 +1878,7 @@ void test_sync_threads(const float* __restrict__ A, float* __restrict__ C)
float* B = _B_temp_buffer;
for (int32_t i = 0; i < 64; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
B_local_temp_buffer[((32 * i) + j)] = (2.00000000f * A[((32 * i) + j)]);
B_local_temp_buffer[((32 * i) + j)] = (A[((32 * i) + j)] * 2.00000000f);
};
};
for (int32_t cache_ax0 = 0; cache_ax0 < 64; cache_ax0 += 1) {
Expand All @@ -1889,7 +1889,7 @@ void test_sync_threads(const float* __restrict__ A, float* __restrict__ C)
};
for (int32_t i = 0; i < 64; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
C_local_temp_buffer[((32 * i) + j)] = (1.00000000f + B[((32 * i) + j)]);
C_local_temp_buffer[((32 * i) + j)] = (B[((32 * i) + j)] + 1.00000000f);
};
};
for (int32_t cache_ax0_0 = 0; cache_ax0_0 < 64; cache_ax0_0 += 1) {
Expand Down Expand Up @@ -2716,7 +2716,7 @@ void test_compute_inline1(void* _args, int32_t num_args)
for (int32_t i = 0; i < 32; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
for (int32_t k = 0; k < 32; k += 1) {
C[((1024 * i) + ((32 * j) + k))] = fma(2.00000000f, A[((32 * i) + ((1024 * j) + k))], 2.00000000f);
C[((1024 * i) + ((32 * j) + k))] = ((A[((32 * i) + ((1024 * j) + k))] + 1.00000000f) * 2.00000000f);
};
};
};
Expand Down Expand Up @@ -2790,7 +2790,7 @@ void test_compute_inline2(void* _args, int32_t num_args)
for (int32_t i = 0; i < 32; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
for (int32_t k = 0; k < 32; k += 1) {
C[((1024 * i) + ((32 * j) + k))] = fma(2.00000000f, A[((1024 * i) + ((32 * j) + k))], 2.00000000f);
C[((1024 * i) + ((32 * j) + k))] = ((A[((1024 * i) + ((32 * j) + k))] + 1.00000000f) * 2.00000000f);
};
};
};
Expand Down Expand Up @@ -2855,7 +2855,7 @@ void test_compute_inline3(const float* __restrict__ A, float* __restrict__ C)
for (int32_t i = 0; i < 32; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
for (int32_t k = 0; k < 32; k += 1) {
C[((1024 * i) + ((32 * j) + k))] = (2.00000000f + (2.00000000f * A[((32 * i) + ((1024 * j) + k))]));
C[((1024 * i) + ((32 * j) + k))] = ((A[((32 * i) + ((1024 * j) + k))] + 1.00000000f) * 2.00000000f);
};
};
};
Expand Down Expand Up @@ -2917,7 +2917,7 @@ void test_compute_inline4(const float* __restrict__ A, float* __restrict__ C)
for (int32_t i = 0; i < 32; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
for (int32_t k = 0; k < 32; k += 1) {
C[((1024 * i) + ((32 * j) + k))] = (2.00000000f + (2.00000000f * A[((1024 * i) + ((32 * j) + k))]));
C[((1024 * i) + ((32 * j) + k))] = ((A[((1024 * i) + ((32 * j) + k))] + 1.00000000f) * 2.00000000f);
};
};
};
Expand Down Expand Up @@ -2979,7 +2979,7 @@ void test_compute_inline1(void* _args, int32_t num_args)
float* C = ((float*)(_C->memory));
for (int32_t i = 0; i < 32; i += 1) {
for (int32_t j = 0; j < 64; j += 1) {
C[((32 * j) + i)] = fma(2.00000000f, A[((64 * i) + j)], 2.00000000f);
C[((32 * j) + i)] = (2.00000000f * (1.00000000f + A[((64 * i) + j)]));
};
};
cinn_buffer_free((void*)(0), _B);
Expand Down Expand Up @@ -3047,7 +3047,7 @@ void test_compute_inline1(void* _args, int32_t num_args)
for (int32_t i = 0; i < 32; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
for (int32_t k = 0; k < 32; k += 1) {
C[((32 * i) + ((1024 * j) + k))] = fma(2.00000000f, A[((1024 * i) + ((32 * j) + k))], 2.00000000f);
C[((32 * i) + ((1024 * j) + k))] = (2.00000000f * (1.00000000f + A[((1024 * i) + ((32 * j) + k))]));
};
};
};
Expand Down Expand Up @@ -3125,7 +3125,7 @@ void test_copytransform1(void* _args, int32_t num_args)
for (int32_t j = 0; j < 8; j += 1) {
for (int32_t j_0 = 0; j_0 < 4; j_0 += 1) {
for (int32_t k = 0; k < 32; k += 1) {
B[((8192 * i) + ((1024 * i_0) + ((128 * j) + ((32 * j_0) + k))))] = (1.00000000f + A[((8192 * i) + ((1024 * i_0) + ((128 * j) + ((32 * j_0) + k))))]);
B[((8192 * i) + ((1024 * i_0) + ((128 * j) + ((32 * j_0) + k))))] = (A[((8192 * i) + ((1024 * i_0) + ((128 * j) + ((32 * j_0) + k))))] + 1.00000000f);
};
};
};
Expand All @@ -3136,7 +3136,7 @@ void test_copytransform1(void* _args, int32_t num_args)
for (int32_t j = 0; j < 8; j += 1) {
for (int32_t j_0 = 0; j_0 < 4; j_0 += 1) {
for (int32_t k = 0; k < 32; k += 1) {
C[((8192 * i) + ((1024 * i_0) + ((128 * j) + ((32 * j_0) + k))))] = (2.00000000f * B[((256 * i) + ((32 * i_0) + ((4096 * j) + ((1024 * j_0) + k))))]);
C[((8192 * i) + ((1024 * i_0) + ((128 * j) + ((32 * j_0) + k))))] = (B[((256 * i) + ((32 * i_0) + ((4096 * j) + ((1024 * j_0) + k))))] * 2.00000000f);
};
};
};
Expand Down Expand Up @@ -3214,7 +3214,7 @@ void test_copytransform2(void* _args, int32_t num_args)
for (int32_t i_0 = 0; i_0 < 8; i_0 += 1) {
for (int32_t j = 0; j < 64; j += 1) {
for (int32_t k = 0; k < 128; k += 1) {
B[((65536 * i) + ((8192 * i_0) + ((128 * j) + k)))] = (1.00000000f + A[((65536 * i) + ((8192 * i_0) + ((128 * j) + k)))]);
B[((65536 * i) + ((8192 * i_0) + ((128 * j) + k)))] = (A[((65536 * i) + ((8192 * i_0) + ((128 * j) + k)))] + 1.00000000f);
};
};
};
Expand All @@ -3224,7 +3224,7 @@ void test_copytransform2(void* _args, int32_t num_args)
for (int32_t j = 0; j < 8; j += 1) {
for (int32_t j_0 = 0; j_0 < 4; j_0 += 1) {
for (int32_t k = 0; k < 128; k += 1) {
C[((32768 * i) + ((4096 * i_0) + ((512 * j) + ((128 * j_0) + k))))] = (2.00000000f * B[((65536 * i) + ((8192 * i_0) + ((512 * j) + ((128 * j_0) + k))))]);
C[((32768 * i) + ((4096 * i_0) + ((512 * j) + ((128 * j_0) + k))))] = (B[((65536 * i) + ((8192 * i_0) + ((512 * j) + ((128 * j_0) + k))))] * 2.00000000f);
};
};
};
Expand Down
3 changes: 3 additions & 0 deletions paddle/cinn/common/cas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ Expr AutoSimplify(
Expr u,
const absl::flat_hash_map<std::string, CasInterval>& var_intervals) {
VLOG(7) << "Begin AutoSimplify: " << u;
if (u.type().is_float()) {
return u;
}
u = detail::ConvertCinnToCAS(u);
absl::flat_hash_map<std::string, CasInterval> s_var_intervals;
for (auto& item : var_intervals) {
Expand Down
3 changes: 0 additions & 3 deletions paddle/cinn/common/cas_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -458,9 +458,6 @@ TEST(CAS, cond) {
TEST(CAS, SimplifyFracOp) {
Expr frac = Expr(1) / Expr(7) / Expr(6) / Expr(5) / Expr(4);
EXPECT_EQ(GetStreamCnt(AutoSimplify(frac)), "0");

Expr frac_f = Expr(20.0f) / Expr(2.0f) / Expr(1.0f) / Expr(5.0f);
EXPECT_EQ(GetStreamCnt(AutoSimplify(frac_f)), "2.00000000f");
}

} // namespace common
Expand Down
6 changes: 3 additions & 3 deletions paddle/cinn/ir/test/tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function func_C (_A, _B, _D)
{
serial for (j, 0, 20)
{
D[i, j] = (1.00000000f + ((2.00000000f * A[i, j]) + (2.00000000f * B[i, j])))
D[i, j] = (((A[i, j] + B[i, j]) * 2.00000000f) + 1.00000000f)
}
}
}
Expand Down Expand Up @@ -117,7 +117,7 @@ void fn(void* _args, int32_t num_args)
for (int32_t i = 0; i < 10; i += 1) {
for (int32_t j = 0; j < 10; j += 1) {
for (int32_t k = 0; k < 100; k += 1) {
B[((1000 * i) + ((100 * j) + k))] = (2.00000000f * A_reshape[((1000 * i) + ((100 * j) + k))]);
B[((1000 * i) + ((100 * j) + k))] = (A_reshape[((1000 * i) + ((100 * j) + k))] * 2.00000000f);
};
};
};
Expand Down Expand Up @@ -175,7 +175,7 @@ void fn(void* _args, int32_t num_args)
for (int32_t i = 0; i < 10; i += 1) {
for (int32_t j = 0; j < 10; j += 1) {
for (int32_t k = 0; k < 100; k += 1) {
B[((1000 * i) + ((100 * j) + k))] = (2.00000000f * A_copied_reshape[((1000 * i) + ((100 * j) + k))]);
B[((1000 * i) + ((100 * j) + k))] = (A_copied_reshape[((1000 * i) + ((100 * j) + k))] * 2.00000000f);
};
};
};
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/lang/lower_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ TEST(lower, basic) {
{
serial for (j, 0, 15)
{
B[i, j] = (1.00000000f + A[i, j])
B[i, j] = (A[i, j] + 1.00000000f)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/optim/cache_read_write_replace_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function fn (_A, _B, _C1_write_cache)
{
serial for (j, 0, 100)
{
C1_write_cache[i, j] = (3.00000000f + A[i, j])
C1_write_cache[i, j] = (((A[i, j] + 1.00000000f) + 1.00000000f) + 1.00000000f)
}
}
serial for (i, 0, 100)
Expand Down
6 changes: 3 additions & 3 deletions paddle/cinn/optim/ir_simplify_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ TEST(IrSimplify, basic) {
// get (((C[(i * 20)] + 0) + 100) + 24.5)
Simplify(&B);
LOG(INFO) << "simplified: " << B;
auto out = "(124.500000f + C[i, 0])";
auto out = "(((C[i, 0] + 0.00000000f) + 100.000000f) + 24.5000000f)";
EXPECT_EQ(out, utils::GetStreamCnt(B));
}

Expand Down Expand Up @@ -69,7 +69,7 @@ TEST(IrSimplify, basic) {
{
serial for (j, 0, 20)
{
B[i, j] = (125.000000f + (X[i, j] + y[i, 0]))
B[i, j] = ((((((X[i, j] + (y[i, 0] * 1.00000000f)) + (0.00000000f * X[i, j])) + 25.0000000f) + 100.000000f) - 0.00000000f) + 0.00000000f)
}
}
}
Expand Down Expand Up @@ -104,7 +104,7 @@ TEST(IrSimplify, basic) {
{
serial for (j, 0, 20)
{
B[i, j] = ((y[i, 0] / 3.00000000f) + (125.000000f + X[(1000 * i), 0]))
B[i, j] = ((((((X[(1000 * i), 0] + (y[i, 0] / 3.00000000f)) + (0.00000000f * X[i, j])) + 25.0000000f) + 100.000000f) - 0.00000000f) + 0.00000000f)
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions paddle/cinn/optim/optimize_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ TEST(Optimize, Unroll) {
{
serial for (j_outer, 0, 4)
{
C[i, (5 * j_outer)] = (1.00000000f + A[i, (5 * j_outer)])
C[i, (1 + (5 * j_outer))] = (1.00000000f + A[i, (1 + (5 * j_outer))])
C[i, (2 + (5 * j_outer))] = (1.00000000f + A[i, (2 + (5 * j_outer))])
C[i, (3 + (5 * j_outer))] = (1.00000000f + A[i, (3 + (5 * j_outer))])
C[i, (4 + (5 * j_outer))] = (1.00000000f + A[i, (4 + (5 * j_outer))])
C[i, (5 * j_outer)] = (A[i, (5 * j_outer)] + 1.00000000f)
C[i, (1 + (5 * j_outer))] = (A[i, (1 + (5 * j_outer))] + 1.00000000f)
C[i, (2 + (5 * j_outer))] = (A[i, (2 + (5 * j_outer))] + 1.00000000f)
C[i, (3 + (5 * j_outer))] = (A[i, (3 + (5 * j_outer))] + 1.00000000f)
C[i, (4 + (5 * j_outer))] = (A[i, (4 + (5 * j_outer))] + 1.00000000f)
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions paddle/cinn/poly/schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ TEST(CreateStages, compute_at) {
{
serial for (j, 0, 100)
{
B[i, j] = (1.00000000f + A[i, j])
B[i, j] = (A[i, j] + 1.00000000f)
serial for (k, 0, 100)
{
C[i, j, k] = (B[i, j] * B[j, k])
Expand Down Expand Up @@ -99,21 +99,21 @@ TEST(CreateStages, buffer_bind_to_multiple_tensors_schedule) {
{
serial for (j, 0, 100)
{
B[i, j] = (1.00000000f + A[i, j])
B[i, j] = (A[i, j] + 1.00000000f)
}
}
serial for (i, 0, 100)
{
serial for (j, 0, 100)
{
C[i, j] = (1.00000000f + A[i, j])
C[i, j] = (A[i, j] + 1.00000000f)
}
}
serial for (i, 0, 100)
{
serial for (j, 0, 100)
{
D[i, j] = (1.00000000f + A[i, j])
D[i, j] = (A[i, j] + 1.00000000f)
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions paddle/cinn/poly/stage_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ function fn (_A, _A1, _B)
}
serial for (j, 0, 32)
{
B[((16 * i_outer) + i_inner), j] = (A1[((16 * i_outer) + i_inner), j] + (A1[(1 + ((16 * i_outer) + i_inner)), j] + A1[(2 + ((16 * i_outer) + i_inner)), j]))
B[((16 * i_outer) + i_inner), j] = ((A1[((16 * i_outer) + i_inner), j] + A1[(1 + ((16 * i_outer) + i_inner)), j]) + A1[(2 + ((16 * i_outer) + i_inner)), j])
}
}
}
Expand Down Expand Up @@ -431,7 +431,7 @@ function fn (_A, _C)
{
serial for (j, 0, 200)
{
C[i, j] = (6.00000000f + (2.00000000f * A[i, j]))
C[i, j] = ((((A[i, j] + 1.00000000f) + 1.00000000f) + 1.00000000f) * 2.00000000f)
}
}
}
Expand Down Expand Up @@ -475,21 +475,21 @@ function fn (_A, _C, _C1, _C2)
{
serial for (j, 0, 200)
{
C2[i, j] = (6.00000000f + (2.00000000f * A[i, j]))
C2[i, j] = ((((A[i, j] + 1.00000000f) + 1.00000000f) + 1.00000000f) * 2.00000000f)
}
}
serial for (i, 0, 100)
{
serial for (j, 0, 200)
{
C1[i, j] = (4.00000000f + (2.00000000f * A[i, j]))
C1[i, j] = (((A[i, j] + 1.00000000f) + 1.00000000f) * 2.00000000f)
}
}
serial for (i, 0, 100)
{
serial for (j, 0, 200)
{
C[i, j] = (2.00000000f + (2.00000000f * A[i, j]))
C[i, j] = ((A[i, j] + 1.00000000f) * 2.00000000f)
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion test/cinn/ir/test_llir_schedule_fuse_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def elementwise_fuse_assign_loop(
i_j_k_fused % 128,
],
)
Y[i1, j1, k1] = 2.0 * X[i1, j1, k1]
Y[i1, j1, k1] = X[i1, j1, k1] * 2.0

assert str(origin.elementwise_fuse_assign_loop) == str(
expected.elementwise_fuse_assign_loop
Expand Down

0 comments on commit c00f1de

Please sign in to comment.