Skip to content

Commit

Permalink
fix se_e3 tabulate op (#2552)
Browse files Browse the repository at this point in the history
Fix #2250.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored May 22, 2023
1 parent 44b793f commit 450455f
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 110 deletions.
28 changes: 1 addition & 27 deletions source/lib/src/cuda/tabulate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -337,20 +337,11 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial(

FPTYPE sum = (FPTYPE)0.;
for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago = __shfl_sync(
0xffffffff,
em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
int breakpoint = nnei_j - 1;
bool unloop = false;
FPTYPE var[6];
int mark_table_idx = -1;
for (int jj = 0; jj < nnei_j; jj++) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE tmp = xx;
if (xx == ago) {
unloop = true;
breakpoint = jj;
}
int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
if (table_idx != mark_table_idx) {
Expand All @@ -363,9 +354,8 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial(
(var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
xx;

sum += (nnei_j - breakpoint) * tmp * res;
sum += tmp * res;
mark_table_idx = table_idx;
if (unloop) break;
}
}
out[block_idx * last_layer_size + thread_idx] = sum;
Expand Down Expand Up @@ -399,16 +389,9 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial(
__syncthreads();

for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago = __shfl_sync(
0xffffffff,
em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
bool unloop = false;
for (int jj = warp_idx; jj < nnei_j; jj += KTILE) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE tmp = xx;
if (ago == xx) {
unloop = true;
}
int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
FPTYPE sum = (FPTYPE)0.;
Expand Down Expand Up @@ -438,7 +421,6 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial(
dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = sum;
dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = Csub;
}
if (unloop) break;
}
}
}
Expand All @@ -464,10 +446,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(

FPTYPE sum = (FPTYPE)0.;
for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago = __shfl_sync(
0xffffffff,
em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
bool unloop = false;
int mark_table_idx = -1;
for (int jj = 0; ii < nnei_j; jj++) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
Expand All @@ -476,9 +454,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE dz_em = dz_dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE var[6];
if (ago == xx) {
unloop = true;
}

int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
Expand All @@ -498,7 +473,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(

sum += (tmp * res_grad * dz_xx + dz_em * res);
mark_table_idx = table_idx;
if (unloop) break;
}
}
dz_dy[block_idx * last_layer_size + thread_idx] = sum;
Expand Down
21 changes: 1 addition & 20 deletions source/lib/src/rocm/tabulate.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -312,17 +312,9 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial(

FPTYPE sum = (FPTYPE)0.;
for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago =
__shfl(em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
int breakpoint = nnei_j - 1;
bool unloop = false;
for (int jj = 0; jj < nnei_j; jj++) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE tmp = xx;
if (xx == ago) {
unloop = true;
breakpoint = jj;
}
int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
FPTYPE var[6];
Expand All @@ -338,8 +330,7 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial(
(var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
xx;

sum += (nnei_j - breakpoint) * tmp * res;
if (unloop) break;
sum += tmp * res;
}
}
out[block_idx * last_layer_size + thread_idx] = sum;
Expand Down Expand Up @@ -375,13 +366,9 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial(
for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago =
__shfl(em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
bool unloop = false;
for (int jj = warp_idx; jj < nnei_j; jj += KTILE) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE tmp = xx;
if (ago == xx) {
unloop = true;
}
int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
FPTYPE sum = (FPTYPE)0.;
Expand Down Expand Up @@ -417,7 +404,6 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial(
dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = sum;
dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = Csub;
}
if (unloop) break;
}
}
}
Expand Down Expand Up @@ -445,17 +431,13 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago =
__shfl(em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
bool unloop = false;
for (int jj = 0; ii < nnei_j; jj++) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE tmp = xx;
FPTYPE dz_xx =
dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE dz_em = dz_dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE var[6];
if (ago == xx) {
unloop = true;
}

int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
Expand All @@ -478,7 +460,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
xx;

sum += (tmp * res_grad * dz_xx + dz_em * res);
if (unloop) break;
}
}
dz_dy[block_idx * last_layer_size + thread_idx] = sum;
Expand Down
50 changes: 9 additions & 41 deletions source/lib/src/tabulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,14 +322,10 @@ void deepmd::tabulate_fusion_se_t_cpu(FPTYPE* out,
#pragma omp parallel for
for (int ii = 0; ii < nloc; ii++) {
for (int jj = 0; jj < nnei_i; jj++) {
FPTYPE ago = em_x[ii * nnei_i * nnei_j + jj * nnei_j + nnei_j - 1];
bool unloop = false;
// unloop does not work because em_x is not sorted
for (int kk = 0; kk < nnei_j; kk++) {
FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk];
FPTYPE ll = xx;
if (ago == xx) {
unloop = true;
}
int table_idx = 0;
locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx,
table_idx);
Expand All @@ -342,13 +338,8 @@ void deepmd::tabulate_fusion_se_t_cpu(FPTYPE* out,
FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5];
FPTYPE var =
a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx;
if (unloop) {
out[ii * last_layer_size + mm] += (nnei_j - kk) * var * ll;
} else {
out[ii * last_layer_size + mm] += var * ll;
}
out[ii * last_layer_size + mm] += var * ll;
}
if (unloop) break;
}
}
}
Expand Down Expand Up @@ -380,15 +371,10 @@ void deepmd::tabulate_fusion_se_t_grad_cpu(FPTYPE* dy_dem_x,
FPTYPE ll = (FPTYPE)0.;
FPTYPE rr = (FPTYPE)0.;
for (int jj = 0; jj < nnei_i; jj++) {
FPTYPE ago = em_x[ii * nnei_i * nnei_j + jj * nnei_j + nnei_j - 1];
bool unloop = false;
for (int kk = 0; kk < nnei_j; kk++) {
// construct the dy/dx
FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk];
ll = xx;
if (ago == xx) {
unloop = true;
}
int table_idx = 0;
locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx,
table_idx);
Expand All @@ -404,27 +390,15 @@ void deepmd::tabulate_fusion_se_t_grad_cpu(FPTYPE* dy_dem_x,
FPTYPE res =
a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx;

if (unloop) {
grad += (a1 + ((FPTYPE)2. * a2 +
((FPTYPE)3. * a3 +
((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) *
xx) *
xx) *
ll * rr * (nnei_j - kk);
dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk] +=
res * rr * (nnei_j - kk);
} else {
grad += (a1 + ((FPTYPE)2. * a2 +
((FPTYPE)3. * a3 +
((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) *
xx) *
xx) *
ll * rr;
dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk] += res * rr;
}
grad += (a1 + ((FPTYPE)2. * a2 +
((FPTYPE)3. * a3 +
((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) *
xx) *
xx) *
ll * rr;
dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk] += res * rr;
}
dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk] = grad;
if (unloop) break;
}
}
}
Expand Down Expand Up @@ -453,17 +427,12 @@ void deepmd::tabulate_fusion_se_t_grad_grad_cpu(FPTYPE* dz_dy,
#pragma omp parallel for
for (int ii = 0; ii < nloc; ii++) {
for (int jj = 0; jj < nnei_i; jj++) {
FPTYPE ago = em_x[ii * nnei_i * nnei_j + jj * nnei_j + nnei_j - 1];
bool unloop = false;
for (int kk = 0; kk < nnei_j; kk++) {
FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk];
FPTYPE tmp = xx;
FPTYPE dz_em = dz_dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk];
FPTYPE dz_xx = dz_dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk];

if (ago == xx) {
unloop = true;
}
int table_idx = 0;
locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx,
table_idx);
Expand All @@ -486,7 +455,6 @@ void deepmd::tabulate_fusion_se_t_grad_grad_cpu(FPTYPE* dz_dy,
dz_dy[ii * last_layer_size + mm] +=
var * dz_em + dz_xx * var_grad * tmp;
}
if (unloop) break;
}
}
}
Expand Down
66 changes: 44 additions & 22 deletions source/tests/test_model_compression_se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,28 @@ def _init_models():
INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models()


def tearDownModule():
    """Module-level teardown hook: pass every generated path to ``_file_delete``.

    The three model paths come from the module constants ``INPUT``,
    ``FROZEN_MODEL`` and ``COMPRESSED_MODEL``; the rest are fixed file names.
    Presumably ``_file_delete`` removes the given path -- confirm against its
    definition, which is not part of this block.
    """
    # Order is kept exactly as listed: "model-compression" itself is handed
    # over only after the paths underneath it.
    leftovers = (
        INPUT,
        FROZEN_MODEL,
        COMPRESSED_MODEL,
        "out.json",
        "compress.json",
        "checkpoint",
        "model.ckpt.meta",
        "model.ckpt.index",
        "model.ckpt.data-00000-of-00001",
        "model.ckpt-100.meta",
        "model.ckpt-100.index",
        "model.ckpt-100.data-00000-of-00001",
        "model-compression/checkpoint",
        "model-compression/model.ckpt.meta",
        "model-compression/model.ckpt.index",
        "model-compression/model.ckpt.data-00000-of-00001",
        "model-compression",
        "input_v2_compat.json",
        "lcurve.out",
    )
    for leftover in leftovers:
        _file_delete(leftover)


class TestDeepPotAPBC(unittest.TestCase):
@classmethod
def setUpClass(self):
Expand Down Expand Up @@ -444,28 +466,6 @@ def setUpClass(self):
self.atype = [0, 1, 1, 0, 1, 1]
self.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0])

@classmethod
def tearDownClass(self):
    """Class-level teardown hook: pass every generated path to ``_file_delete``.

    Covers the three model paths held in ``INPUT``, ``FROZEN_MODEL`` and
    ``COMPRESSED_MODEL`` plus a fixed list of file names. Presumably
    ``_file_delete`` removes the given path -- confirm against its definition.
    """
    model_paths = [INPUT, FROZEN_MODEL, COMPRESSED_MODEL]
    fixed_names = [
        "out.json",
        "compress.json",
        "checkpoint",
        "model.ckpt.meta",
        "model.ckpt.index",
        "model.ckpt.data-00000-of-00001",
        "model.ckpt-100.meta",
        "model.ckpt-100.index",
        "model.ckpt-100.data-00000-of-00001",
        "model-compression/checkpoint",
        "model-compression/model.ckpt.meta",
        "model-compression/model.ckpt.index",
        "model-compression/model.ckpt.data-00000-of-00001",
        # the "model-compression" entry itself follows the paths beneath it
        "model-compression",
        "input_v2_compat.json",
        "lcurve.out",
    ]
    for target in model_paths + fixed_names:
        _file_delete(target)

def test_attrs(self):
self.assertEqual(self.dp_original.get_ntypes(), 2)
self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places)
Expand Down Expand Up @@ -558,3 +558,25 @@ def test_2frame_atm(self):
np.testing.assert_almost_equal(av0, av1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)


class TestDeepPotAPBC2(TestDeepPotAPBC):
    """Variant of ``TestDeepPotAPBC`` that only swaps the class-level fixture.

    Nothing but ``setUpClass`` is overridden, so everything else is inherited
    from ``TestDeepPotAPBC``. The fixture uses three entries in ``atype``,
    all of type 0.
    """

    @classmethod
    def setUpClass(cls):
        # Load the frozen and the compressed model from the module-level paths.
        cls.dp_original = DeepPot(FROZEN_MODEL)
        cls.dp_compressed = DeepPot(COMPRESSED_MODEL)

        # Nine values in one flat array; laid out below three per row.
        # Presumably one row is the x, y, z of one atom -- TODO confirm.
        flat_coords = [
            0.0, 0.0, 0.0,
            2.0, 0.0, 0.0,
            0.0, 2.0, 0.0,
        ]
        cls.coords = np.array(flat_coords)

        cls.atype = [0, 0, 0]

        # Nine values with 13.0 at positions 0, 4 and 8; presumably a
        # flattened 3x3 cell matrix -- TODO confirm.
        flat_box = [
            13.0, 0.0, 0.0,
            0.0, 13.0, 0.0,
            0.0, 0.0, 13.0,
        ]
        cls.box = np.array(flat_box)

0 comments on commit 450455f

Please sign in to comment.