Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug of custom op's multiple initialization #812

Merged
merged 13 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions source/api_cc/src/DeepPot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ run_model (ENERGYTYPE & dener,

std::vector<Tensor> output_tensors;
check_status (session->Run(input_tensors,
{"o_energy", "o_force", "o_atom_virial"},
{"o_energy", "o_force", "o_atom_energy", "o_atom_virial"},
{},
&output_tensors));

Tensor output_e = output_tensors[0];
Tensor output_f = output_tensors[1];
Tensor output_av = output_tensors[2];
Tensor output_av = output_tensors[3];

auto oe = output_e.flat <ENERGYTYPE> ();
auto of = output_f.flat <VALUETYPE> ();
Expand Down
31 changes: 16 additions & 15 deletions source/lib/src/cuda/prod_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,23 @@ __global__ void force_deriv_wrt_neighbors_a(
const int nnei)
{
// idy -> nnei
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
const unsigned int idy = blockIdx.y;
const unsigned int idx = blockIdx.x;
const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x;
const unsigned int idz = threadIdx.y;
const unsigned int idw = threadIdx.z;
const int ndescrpt = nnei * 4;
if (idx >= nloc) {
if (idy >= nnei) {
return;
}
// deriv wrt neighbors
int j_idx = nlist[idx * nnei + idy];
if (j_idx < 0) {
return;
}
atomicAdd(
force + j_idx * 3 + idz,
net_deriv[idx * ndescrpt + idy * 4 + idw] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz]);
FPTYPE force_tmp = 0.f;
for (int idw = 0; idw < 4; ++idw) {
force_tmp += net_deriv[idx * ndescrpt + idy * 4 + idw] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz];
}
atomicAdd(force + j_idx * 3 + idz, force_tmp);
}

template<typename FPTYPE>
Expand All @@ -78,11 +79,11 @@ __global__ void force_deriv_wrt_neighbors_r(
const int nnei)
{
// idy -> nnei
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
const unsigned int idy = blockIdx.y;
const unsigned int idx = blockIdx.x;
const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x;
const unsigned int idz = threadIdx.y;
const int ndescrpt = nnei * 1;
if (idx >= nloc) {
if (idy >= nnei) {
return;
}
// deriv wrt neighbors
Expand Down Expand Up @@ -116,9 +117,9 @@ void prod_force_a_gpu_cuda(
net_deriv, in_deriv, ndescrpt);

const int LEN = 64;
const int nblock = (nloc + LEN -1) / LEN;
dim3 block_grid(nblock, nnei);
dim3 thread_grid(LEN, 3, 4);
const int nblock = (nnei + LEN - 1) / LEN;
dim3 block_grid(nloc, nblock);
dim3 thread_grid(LEN, 3);
force_deriv_wrt_neighbors_a<<<block_grid, thread_grid>>>(
force,
net_deriv, in_deriv, nlist, nloc, nnei);
Expand All @@ -144,8 +145,8 @@ void prod_force_r_gpu_cuda(
net_deriv, in_deriv, ndescrpt);

const int LEN = 64;
const int nblock = (nloc + LEN -1) / LEN;
dim3 block_grid(nblock, nnei);
const int nblock = (nnei + LEN - 1) / LEN;
dim3 block_grid(nloc, nblock);
dim3 thread_grid(LEN, 3);
force_deriv_wrt_neighbors_r<<<block_grid, thread_grid>>>(
force,
Expand Down
31 changes: 16 additions & 15 deletions source/lib/src/cuda/prod_virial.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,11 @@ __global__ void virial_deriv_wrt_neighbors_a(
// idz = dd0 * 3 + dd1
// dd0 = idz / 3
// dd1 = idz % 3
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
const unsigned int idy = blockIdx.y;
const unsigned int idx = blockIdx.x;
const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x;
const unsigned int idz = threadIdx.y;
const unsigned int idw = threadIdx.z;
const int ndescrpt = nnei * 4;
if (idx >= nloc) {
if (idy >= nnei) {
return;
}
int j_idx = nlist[idx * nnei + idy];
Expand All @@ -60,9 +59,11 @@ __global__ void virial_deriv_wrt_neighbors_a(
// atomicAdd(
// virial + idz,
// net_deriv[idx * ndescrpt + idy * 4 + idw] * rij[idx * nnei * 3 + idy * 3 + idz / 3] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz % 3]);
atomicAdd(
atom_virial + j_idx * 9 + idz,
net_deriv[idx * ndescrpt + idy * 4 + idw] * rij[idx * nnei * 3 + idy * 3 + idz % 3] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz / 3]);
FPTYPE virial_tmp = 0.f;
for (int idw = 0; idw < 4; ++idw) {
virial_tmp += net_deriv[idx * ndescrpt + idy * 4 + idw] * rij[idx * nnei * 3 + idy * 3 + idz % 3] * in_deriv[idx * ndescrpt * 3 + (idy * 4 + idw) * 3 + idz / 3];
}
atomicAdd(atom_virial + j_idx * 9 + idz, virial_tmp);
}

template<typename FPTYPE>
Expand All @@ -81,12 +82,12 @@ __global__ void virial_deriv_wrt_neighbors_r(
// idz = dd0 * 3 + dd1
// dd0 = idz / 3
// dd1 = idz % 3
const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
const unsigned int idy = blockIdx.y;
const unsigned int idx = blockIdx.x;
const unsigned int idy = blockIdx.y * blockDim.x + threadIdx.x;
const unsigned int idz = threadIdx.y;
const int ndescrpt = nnei * 1;

if (idx >= nloc) {
if (idy >= nnei) {
return;
}
int j_idx = nlist[idx * nnei + idy];
Expand Down Expand Up @@ -122,9 +123,9 @@ void prod_virial_a_gpu_cuda(
0.0, sizeof(FPTYPE) * 9 * nall));

const int LEN = 16;
int nblock = (nloc + LEN -1) / LEN;
dim3 block_grid(nblock, nnei);
dim3 thread_grid(LEN, 9, 4);
int nblock = (nnei + LEN - 1) / LEN;
dim3 block_grid(nloc, nblock);
dim3 thread_grid(LEN, 9);
// compute virial of a frame
virial_deriv_wrt_neighbors_a<<<block_grid, thread_grid>>>(
virial, atom_virial,
Expand Down Expand Up @@ -155,8 +156,8 @@ void prod_virial_r_gpu_cuda(
0.0, sizeof(FPTYPE) * 9 * nall));

const int LEN = 16;
int nblock = (nloc + LEN -1) / LEN;
dim3 block_grid(nblock, nnei);
int nblock = (nnei + LEN - 1) / LEN;
dim3 block_grid(nloc, nblock);
dim3 thread_grid(LEN, 9);
// compute virial of a frame
virial_deriv_wrt_neighbors_r<<<block_grid, thread_grid>>>(
Expand Down