From 450455f4e6200aebad427ba2bc98b380ec38928b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 22 May 2023 10:46:45 -0400 Subject: [PATCH] fix se_e3 tabulate op (#2552) Fix #2250. --------- Signed-off-by: Jinzhe Zeng --- source/lib/src/cuda/tabulate.cu | 28 +-------- source/lib/src/rocm/tabulate.hip.cu | 21 +------ source/lib/src/tabulate.cc | 50 +++------------- source/tests/test_model_compression_se_t.py | 66 ++++++++++++++------- 4 files changed, 55 insertions(+), 110 deletions(-) diff --git a/source/lib/src/cuda/tabulate.cu b/source/lib/src/cuda/tabulate.cu index 2e8c24cf99..5c3d360c25 100644 --- a/source/lib/src/cuda/tabulate.cu +++ b/source/lib/src/cuda/tabulate.cu @@ -337,20 +337,11 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial( FPTYPE sum = (FPTYPE)0.; for (int ii = 0; ii < nnei_i; ii++) { - FPTYPE ago = __shfl_sync( - 0xffffffff, - em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0); - int breakpoint = nnei_j - 1; - bool unloop = false; FPTYPE var[6]; int mark_table_idx = -1; for (int jj = 0; jj < nnei_j; jj++) { FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; FPTYPE tmp = xx; - if (xx == ago) { - unloop = true; - breakpoint = jj; - } int table_idx = 0; locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1); if (table_idx != mark_table_idx) { @@ -363,9 +354,8 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial( (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx; - sum += (nnei_j - breakpoint) * tmp * res; + sum += tmp * res; mark_table_idx = table_idx; - if (unloop) break; } } out[block_idx * last_layer_size + thread_idx] = sum; @@ -399,16 +389,9 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial( __syncthreads(); for (int ii = 0; ii < nnei_i; ii++) { - FPTYPE ago = __shfl_sync( - 0xffffffff, - em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0); - bool unloop = false; for (int jj = warp_idx; jj < nnei_j; jj += KTILE) { FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; FPTYPE tmp = xx; - if (ago == xx) { - unloop = true; - } int table_idx = 0; locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1); FPTYPE sum = (FPTYPE)0.; @@ -438,7 +421,6 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial( dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = sum; dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = Csub; } - if (unloop) break; } } } @@ -464,10 +446,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial( FPTYPE sum = (FPTYPE)0.; for (int ii = 0; ii < nnei_i; ii++) { - FPTYPE ago = __shfl_sync( - 0xffffffff, - em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0); - bool unloop = false; int mark_table_idx = -1; for (int jj = 0; ii < nnei_j; jj++) { FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; @@ -476,9 +454,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial( dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; FPTYPE dz_em = dz_dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; FPTYPE var[6]; - if (ago == xx) { - unloop = true; - } int table_idx = 0; locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1); @@ -498,7 +473,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial( sum += (tmp * res_grad * dz_xx + dz_em * res); mark_table_idx = table_idx; - if (unloop) break; } } dz_dy[block_idx * last_layer_size + thread_idx] = sum; diff --git a/source/lib/src/rocm/tabulate.hip.cu b/source/lib/src/rocm/tabulate.hip.cu index 1ac1bca8f8..b356cb6f3e 100644 --- a/source/lib/src/rocm/tabulate.hip.cu +++ b/source/lib/src/rocm/tabulate.hip.cu @@ -312,17 +312,9 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial( FPTYPE sum = (FPTYPE)0.; for (int ii = 0; ii < nnei_i; ii++) { - FPTYPE ago = - __shfl(em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0); - int breakpoint = nnei_j - 1; - bool unloop = false; for (int jj = 0; jj < nnei_j; jj++) { FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; FPTYPE tmp = xx; - if (xx == ago) { - unloop = true; - breakpoint = jj; - } int table_idx = 0; locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1); FPTYPE var[6]; @@ -338,8 +330,7 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial( (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) * xx; - sum += (nnei_j - breakpoint) * tmp * res; - if (unloop) break; + sum += tmp * res; } } out[block_idx * last_layer_size + thread_idx] = sum; @@ -375,13 +366,9 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial( for (int ii = 0; ii < nnei_i; ii++) { FPTYPE ago = __shfl(em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0); - bool unloop = false; for (int jj = warp_idx; jj < nnei_j; jj += KTILE) { FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; FPTYPE tmp = xx; - if (ago == xx) { - unloop = true; - } int table_idx = 0; locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1); FPTYPE sum = (FPTYPE)0.; @@ -417,7 +404,6 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial( dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = sum; dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = Csub; } - if (unloop) break; } } } @@ -445,7 +431,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial( for (int ii = 0; ii < nnei_i; ii++) { FPTYPE ago = __shfl(em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0); - bool unloop = false; for (int jj = 0; ii < nnei_j; jj++) { FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; FPTYPE tmp = xx; @@ -453,9 +438,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial( dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; FPTYPE dz_em = dz_dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj]; FPTYPE var[6]; - if (ago == xx) { - unloop = true; - } int table_idx = 0; locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1); @@ -478,7 +460,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial( xx; sum += (tmp * res_grad * dz_xx + dz_em * res); - if (unloop) break; } } dz_dy[block_idx * last_layer_size + thread_idx] = sum; diff --git a/source/lib/src/tabulate.cc b/source/lib/src/tabulate.cc index dcc17e07e9..67921d4cc7 100644 --- a/source/lib/src/tabulate.cc +++ b/source/lib/src/tabulate.cc @@ -322,14 +322,10 @@ void deepmd::tabulate_fusion_se_t_cpu(FPTYPE* out, #pragma omp parallel for for (int ii = 0; ii < nloc; ii++) { for (int jj = 0; jj < nnei_i; jj++) { - FPTYPE ago = em_x[ii * nnei_i * nnei_j + jj * nnei_j + nnei_j - 1]; - bool unloop = false; + // unloop not work as em_x is not sorted for (int kk = 0; kk < nnei_j; kk++) { FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; FPTYPE ll = xx; - if (ago == xx) { - unloop = true; - } int table_idx = 0; locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx, table_idx); @@ -342,13 +338,8 @@ void deepmd::tabulate_fusion_se_t_cpu(FPTYPE* out, FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5]; FPTYPE var = a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; - if (unloop) { - out[ii * last_layer_size + mm] += (nnei_j - kk) * var * ll; - } else { - out[ii * last_layer_size + mm] += var * ll; - } + out[ii * last_layer_size + mm] += var * ll; } - if (unloop) break; } } } @@ -380,15 +371,10 @@ void deepmd::tabulate_fusion_se_t_grad_cpu(FPTYPE* dy_dem_x, FPTYPE ll = (FPTYPE)0.; FPTYPE rr = (FPTYPE)0.; for (int jj = 0; jj < nnei_i; jj++) { - FPTYPE ago = em_x[ii * nnei_i * nnei_j + jj * nnei_j + nnei_j - 1]; - bool unloop = false; for (int kk = 0; kk < nnei_j; kk++) { // construct the dy/dx FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; ll = xx; - if (ago == xx) { - unloop = true; - } int table_idx = 0; locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx, table_idx); @@ -404,27 +390,15 @@ void deepmd::tabulate_fusion_se_t_grad_cpu(FPTYPE* dy_dem_x, FPTYPE res = a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx; - if (unloop) { - grad += (a1 + ((FPTYPE)2. * a2 + - ((FPTYPE)3. * a3 + - ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) * - xx) * - xx) * - ll * rr * (nnei_j - kk); - dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk] += - res * rr * (nnei_j - kk); - } else { - grad += (a1 + ((FPTYPE)2. * a2 + - ((FPTYPE)3. * a3 + - ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) * - xx) * - xx) * - ll * rr; - dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk] += res * rr; - } + grad += (a1 + ((FPTYPE)2. * a2 + + ((FPTYPE)3. * a3 + + ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) * + xx) * + xx) * + ll * rr; + dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk] += res * rr; } dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk] = grad; - if (unloop) break; } } } @@ -453,17 +427,12 @@ void deepmd::tabulate_fusion_se_t_grad_grad_cpu(FPTYPE* dz_dy, #pragma omp parallel for for (int ii = 0; ii < nloc; ii++) { for (int jj = 0; jj < nnei_i; jj++) { - FPTYPE ago = em_x[ii * nnei_i * nnei_j + jj * nnei_j + nnei_j - 1]; - bool unloop = false; for (int kk = 0; kk < nnei_j; kk++) { FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; FPTYPE tmp = xx; FPTYPE dz_em = dz_dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk]; FPTYPE dz_xx = dz_dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk]; - if (ago == xx) { - unloop = true; - } int table_idx = 0; locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx, table_idx); @@ -486,7 +455,6 @@ void deepmd::tabulate_fusion_se_t_grad_grad_cpu(FPTYPE* dz_dy, dz_dy[ii * last_layer_size + mm] += var * dz_em + dz_xx * var_grad * tmp; } - if (unloop) break; } } } diff --git a/source/tests/test_model_compression_se_t.py b/source/tests/test_model_compression_se_t.py index 6d93067c6c..0e33430158 100644 --- a/source/tests/test_model_compression_se_t.py +++ b/source/tests/test_model_compression_se_t.py @@ -74,6 +74,28 @@ def _init_models(): INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() +def tearDownModule(): + _file_delete(INPUT) + _file_delete(FROZEN_MODEL) + _file_delete(COMPRESSED_MODEL) + _file_delete("out.json") + _file_delete("compress.json") + _file_delete("checkpoint") + _file_delete("model.ckpt.meta") + _file_delete("model.ckpt.index") + _file_delete("model.ckpt.data-00000-of-00001") + _file_delete("model.ckpt-100.meta") + _file_delete("model.ckpt-100.index") + _file_delete("model.ckpt-100.data-00000-of-00001") + _file_delete("model-compression/checkpoint") + _file_delete("model-compression/model.ckpt.meta") + _file_delete("model-compression/model.ckpt.index") + _file_delete("model-compression/model.ckpt.data-00000-of-00001") + _file_delete("model-compression") + _file_delete("input_v2_compat.json") + _file_delete("lcurve.out") + + class TestDeepPotAPBC(unittest.TestCase): @classmethod def setUpClass(self): @@ -444,28 +466,6 @@ def setUpClass(self): self.atype = [0, 1, 1, 0, 1, 1] self.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) - @classmethod - def tearDownClass(self): - _file_delete(INPUT) - _file_delete(FROZEN_MODEL) - _file_delete(COMPRESSED_MODEL) - _file_delete("out.json") - _file_delete("compress.json") - _file_delete("checkpoint") - _file_delete("model.ckpt.meta") - _file_delete("model.ckpt.index") - _file_delete("model.ckpt.data-00000-of-00001") - _file_delete("model.ckpt-100.meta") - _file_delete("model.ckpt-100.index") - _file_delete("model.ckpt-100.data-00000-of-00001") - _file_delete("model-compression/checkpoint") - _file_delete("model-compression/model.ckpt.meta") - _file_delete("model-compression/model.ckpt.index") - _file_delete("model-compression/model.ckpt.data-00000-of-00001") - _file_delete("model-compression") - _file_delete("input_v2_compat.json") - _file_delete("lcurve.out") - def test_attrs(self): self.assertEqual(self.dp_original.get_ntypes(), 2) self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places) @@ -558,3 +558,25 @@ def test_2frame_atm(self): np.testing.assert_almost_equal(av0, av1, default_places) np.testing.assert_almost_equal(ee0, ee1, default_places) np.testing.assert_almost_equal(vv0, vv1, default_places) + + +class TestDeepPotAPBC2(TestDeepPotAPBC): + @classmethod + def setUpClass(self): + self.dp_original = DeepPot(FROZEN_MODEL) + self.dp_compressed = DeepPot(COMPRESSED_MODEL) + self.coords = np.array( + [ + 0.0, + 0.0, + 0.0, + 2.0, + 0.0, + 0.0, + 0.0, + 2.0, + 0.0, + ] + ) + self.atype = [0, 0, 0] + self.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0])