Skip to content

Commit

Permalink
fix model compression error (#1043)
Browse files Browse the repository at this point in the history
* fix model compression error

* add doc for model compression limitation
  • Loading branch information
denghuilu authored Aug 27, 2021
1 parent 5ab5fa1 commit 37af1d1
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 6 deletions.
3 changes: 2 additions & 1 deletion doc/train/gpu-limitations.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ If you use deepmd-kit in a GPU environment, the acceptable value range of some v
1. The number of atom types of a given system must be less than 128.
2. The maximum distance between an atom and its neighbors must be less than 128. It can be controlled by setting the rcut value of the training parameters.
3. Theoretically, the maximum number of atoms that a single GPU can accept is about 10,000,000. However, this value is currently limited by the GPU memory size, and is usually within 1,000,000 atoms even in model compression mode.
4. The total sel value of training parameters(in model/descriptor section) must be less than 4096.
4. The total sel value of the training parameters (in the model/descriptor section) must be less than 4096.
5. The size of the last layer of the embedding net must be no greater than 1024 during the model compression process.
4 changes: 2 additions & 2 deletions source/lib/src/cuda/tabulate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial(
bool unloop = false;
FPTYPE * iteratorA = (FPTYPE *)&_data[0]; // dy
for (int ii = 0; ii < MTILE; ii++) {
if (thread_idx < last_layer_size) {
iteratorA[ii * last_layer_size + thread_idx] = dy[block_idx * MTILE * last_layer_size + ii * last_layer_size + thread_idx];
for (int jj = thread_idx; jj < last_layer_size; jj += blockDim.x) {
iteratorA[ii * last_layer_size + jj] = dy[block_idx * MTILE * last_layer_size + ii * last_layer_size + jj];
}
}
__syncthreads();
Expand Down
5 changes: 2 additions & 3 deletions source/lib/src/rocm/tabulate.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#define TPB 256
#define WARP_SIZE 64
#define FULL_MASK 0xffffffff
#include "gpu_rocm.h"

template <typename FPTYPE>
__forceinline__ __device__
Expand Down Expand Up @@ -140,8 +139,8 @@ __global__ void tabulate_fusion_grad_fifth_order_polynomial(
bool unloop = false;
FPTYPE * iteratorA = (FPTYPE *)&_data[0]; // dy
for (int ii = 0; ii < MTILE; ii++) {
if (thread_idx < last_layer_size) {
iteratorA[ii * last_layer_size + thread_idx] = dy[block_idx * MTILE * last_layer_size + ii * last_layer_size + thread_idx];
for (int jj = thread_idx; jj < last_layer_size; jj += blockDim.x) {
iteratorA[ii * last_layer_size + jj] = dy[block_idx * MTILE * last_layer_size + ii * last_layer_size + jj];
}
}
__syncthreads();
Expand Down
1 change: 1 addition & 0 deletions source/op/tabulate_multi_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ class TabulateFusionGradGradOp : public OpKernel {
dz_dy,
table, table_info, em_x, em, dz_dy_dem_x, dz_dy_dem, nloc, nnei, last_layer_size);
#endif // TENSORFLOW_USE_ROCM
OP_REQUIRES (context, (last_layer_size <= 1024), errors::InvalidArgument ("In the process of model compression, the size of the last layer of embedding net must be less than 1024!"));
}
else if (device == "CPU") {
deepmd::tabulate_fusion_grad_grad_cpu(
Expand Down

0 comments on commit 37af1d1

Please sign in to comment.