diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py
index fab872e721..fbc9a77b56 100644
--- a/deepmd/descriptor/se_a.py
+++ b/deepmd/descriptor/se_a.py
@@ -126,6 +126,7 @@ def __init__ (self,
         self.uniform_seed = uniform_seed
         self.seed_shift = embedding_net_rand_seed_shift(self.filter_neuron)
         self.trainable = trainable
+        self.compress_activation_fn = get_activation_func(activation_function)
         self.filter_activation_fn = get_activation_func(activation_function)
         self.filter_precision = get_precision(precision)
         self.filter_np_precision = get_np_precision(precision)
@@ -316,7 +317,8 @@ def enable_compression(self,
                 The overflow check frequency
         """
         self.compress = True
-        self.table = DPTabulate(model_file, self.type_one_side, self.exclude_types)
+        self.table = DPTabulate(
+            model_file, self.type_one_side, self.exclude_types, self.compress_activation_fn)
         self.table_config = [table_extrapolate, table_stride_1, table_stride_2, check_frequency]
         self.lower, self.upper \
             = self.table.build(min_nbor_dist,
diff --git a/deepmd/utils/tabulate.py b/deepmd/utils/tabulate.py
index 719697dc87..f1057b38f2 100644
--- a/deepmd/utils/tabulate.py
+++ b/deepmd/utils/tabulate.py
@@ -2,9 +2,11 @@
 import math
 import logging
 import numpy as np
+from typing import Callable
 from typing import Tuple, List
 from deepmd.env import tf
 from deepmd.env import op_module
+from deepmd.common import ACTIVATION_FN_DICT
 from deepmd.utils.sess import run_sess
 from deepmd.utils.graph import get_tensor_by_name_from_graph, load_graph_def
 from deepmd.utils.graph import get_embedding_net_nodes_from_graph_def
@@ -30,11 +32,14 @@ class DPTabulate():
     exclude_types : List[List[int]]
             The excluded pairs of types which have no interaction with each other.
             For example, `[[0, 1]]` means no interaction between type 0 and type 1.
+    activation_function
+            The activation function in the embedding net. Supported options are {"tanh", "gelu"} in common.ACTIVATION_FN_DICT.
""" def __init__(self, model_file : str, type_one_side : bool = False, - exclude_types : List[List[int]] = []) -> None: + exclude_types : List[List[int]] = [], + activation_fn : Callable[[tf.Tensor], tf.Tensor] = tf.nn.tanh) -> None: """ Constructor """ @@ -44,6 +49,15 @@ def __init__(self, self.exclude_types = exclude_types if self.type_one_side and len(self.exclude_types) != 0: raise RunTimeError('"type_one_side" is not compatible with "exclude_types"') + + # functype + if activation_fn == ACTIVATION_FN_DICT["tanh"]: + self.functype = 1 + elif activation_fn == ACTIVATION_FN_DICT["gelu"]: + self.functype = 2 + else: + raise RunTimeError("Unknown actication function type!") + self.activation_fn = activation_fn self.graph, self.graph_def = load_graph_def(self.model_file) self.sess = tf.Session(graph = self.graph) @@ -199,26 +213,37 @@ def _make_data(self, xx, idx): xx = tf.reshape(xx, [xx.size, -1]) for layer in range(self.layer_size): if layer == 0: - yy = self._layer_0(xx, self.matrix["layer_" + str(layer + 1)][idx], self.bias["layer_" + str(layer + 1)][idx]) - dy = op_module.unaggregated_dy_dx_s(yy, self.matrix["layer_" + str(layer + 1)][idx]) - dy2 = op_module.unaggregated_dy2_dx_s(yy, dy, self.matrix["layer_" + str(layer + 1)][idx]) + xbar = tf.matmul( + xx, self.matrix["layer_" + str(layer + 1)][idx]) + self.bias["layer_" + str(layer + 1)][idx] + yy = self._layer_0( + xx, self.matrix["layer_" + str(layer + 1)][idx], self.bias["layer_" + str(layer + 1)][idx]) + dy = op_module.unaggregated_dy_dx_s( + yy, self.matrix["layer_" + str(layer + 1)][idx], xbar, tf.constant(self.functype)) + dy2 = op_module.unaggregated_dy2_dx_s( + yy, dy, self.matrix["layer_" + str(layer + 1)][idx], xbar, tf.constant(self.functype)) else: - tt, yy = self._layer_1(yy, self.matrix["layer_" + str(layer + 1)][idx], self.bias["layer_" + str(layer + 1)][idx]) - dz = op_module.unaggregated_dy_dx(yy - tt, self.matrix["layer_" + str(layer + 1)][idx], dy) - dy2 = op_module.unaggregated_dy2_dx(yy - tt, self.matrix["layer_" + str(layer + 1)][idx], dz, dy, dy2) + ybar = tf.matmul( + yy, self.matrix["layer_" + str(layer + 1)][idx]) + self.bias["layer_" + str(layer + 1)][idx] + tt, zz = self._layer_1( + yy, self.matrix["layer_" + str(layer + 1)][idx], self.bias["layer_" + str(layer + 1)][idx]) + dz = op_module.unaggregated_dy_dx( + zz - tt, self.matrix["layer_" + str(layer + 1)][idx], dy, ybar, tf.constant(self.functype)) + dy2 = op_module.unaggregated_dy2_dx( + zz - tt, self.matrix["layer_" + str(layer + 1)][idx], dy, dy2, ybar, tf.constant(self.functype)) dy = dz - - vv = yy.eval() + yy = zz + + vv = zz.eval() dd = dy.eval() d2 = dy2.eval() return vv, dd, d2 def _layer_0(self, x, w, b): - return tf.nn.tanh(tf.matmul(x, w) + b) + return self.activation_fn(tf.matmul(x, w) + b) def _layer_1(self, x, w, b): - t = tf.concat([x, x], axis = 1) - return t, tf.nn.tanh(tf.matmul(x, w) + b) + t + t = tf.concat([x, x], axis=1) + return t, self.activation_fn(tf.matmul(x, w) + b) + t def _save_data(self): for ii in range(self.ntypes * self.ntypes): diff --git a/source/op/unaggregated_grad.cc b/source/op/unaggregated_grad.cc index 343a339a92..89c14a84fb 100644 --- a/source/op/unaggregated_grad.cc +++ b/source/op/unaggregated_grad.cc @@ -1,43 +1,90 @@ #include "custom_op.h" #include "ComputeDescriptor.h" #include "neighbor_list.h" +#include "device.h" + +#define GGELU 0.044715 REGISTER_OP("UnaggregatedDyDxS") .Attr("T: {float, double} = DT_DOUBLE") .Input("y: T") - .Input("w: T") + .Input("w: T") + .Input("xbar: T") + .Input("functype: 
int32") .Output("dy_dx: T"); REGISTER_OP("UnaggregatedDyDx") .Attr("T: {float, double} = DT_DOUBLE") .Input("z: T") .Input("w: T") - .Input("dy_dx: T") + .Input("dy_dx: T") + .Input("ybar: T") + .Input("functype: int32") .Output("dz_dx: T"); REGISTER_OP("UnaggregatedDy2DxS") .Attr("T: {float, double} = DT_DOUBLE") .Input("y: T") .Input("dy: T") - .Input("w: T") + .Input("w: T") + .Input("xbar: T") + .Input("functype: int32") .Output("dy2_dx: T"); REGISTER_OP("UnaggregatedDy2Dx") .Attr("T: {float, double} = DT_DOUBLE") .Input("z: T") - .Input("w: T") - .Input("dz_dx: T") + .Input("w: T") .Input("dy_dx: T") .Input("dy2_dx: T") + .Input("ybar: T") + .Input("functype: int32") .Output("dz2_dx: T"); +template +FPTYPE grad(const FPTYPE xbar, const FPTYPE y, const int functype) //functype=tanh, gelu, .. +{ + switch (functype) + { + case 1: + return (1 - y * y); + case 2: + { + const FPTYPE var = tanh(SQRT_2_PI * (xbar + GGELU * xbar * xbar * xbar)); + return 0.5 * SQRT_2_PI * xbar * (1 - var * var) * (3 * GGELU * xbar * xbar + 1) + 0.5 * var + 0.5; + } + default: + return -1; + } + +} + +template +FPTYPE grad_grad(const FPTYPE xbar, const FPTYPE y, const int functype) +{ + switch (functype) + { + case 1: + return -2 * y * (1 - y * y); + case 2: + { + const FPTYPE var1 = tanh(SQRT_2_PI * (xbar + GGELU * xbar * xbar * xbar)); + const FPTYPE var2 = SQRT_2_PI * (1 - var1 * var1) * (3 * GGELU * xbar * xbar + 1); + return 3 * GGELU * SQRT_2_PI * xbar * xbar * (1 - var1 * var1) - SQRT_2_PI * xbar * var2 * (3 * GGELU * xbar * xbar + 1) * var1 + var2; + } + default: + return -1; + } +} + + template struct UnaggregatedDyDxSFunctor { - void operator()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * w, const int length, const int width, FPTYPE * dy_dx) { + void operator()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * w, const FPTYPE* xbar, const int length, const int width, FPTYPE * dy_dx, const int functype) { #pragma omp parallel for for (int ii = 0; ii < length; ii++) { for (int jj = 0; jj < width; jj++) { - dy_dx[ii * width + jj] = (1 - y[ii * width + jj] * y[ii * width + jj]) * w[jj]; + dy_dx[ii * width + jj] = grad(xbar[ii * width + jj], y[ii * width + jj],functype)*w[jj]; } } } @@ -53,12 +100,13 @@ struct UnaggregatedDyDxSFunctor { // calculate the gradient for all variables! 
 template <typename FPTYPE>
 struct UnaggregatedDyDxFunctor {
-    void operator()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dy_dx, const int length, const int width, const int size, FPTYPE * dz_dx) {
+    void operator()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dy_dx, const FPTYPE * ybar, const int length, const int width, const int size, FPTYPE * dz_dx, const int functype) {
+        //width=2*size
 #pragma omp parallel for
         for (int kk = 0; kk < length; kk++) {
             for (int ii = 0; ii < width; ii++) {
                 //FPTYPE dz_drou = 1 - (z[kk * width + ii] - y[kk * size + ii % size]) * (z[kk * width + ii] - y[kk * size + ii % size]);
-                FPTYPE dz_drou = 1 - z[kk * width + ii] * z[kk * width + ii];
+                FPTYPE dz_drou = grad(ybar[kk * width + ii], z[kk * width + ii], functype);
                 FPTYPE accumulator = 0.0;
                 for (int jj = 0; jj < size; jj++) {
                     accumulator += w[jj * width + ii] * dy_dx[kk * size + jj];
@@ -80,11 +128,11 @@ struct UnaggregatedDyDxFunctor {
 template <typename FPTYPE>
 struct UnaggregatedDy2DxSFunctor {
-    void operator()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * dy, const FPTYPE * w, const int length, const int width, FPTYPE * dy2_dx) {
+    void operator()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * dy, const FPTYPE * w, const FPTYPE * xbar, const int length, const int width, FPTYPE * dy2_dx, const int functype) {
 #pragma omp parallel for
         for (int ii = 0; ii < length; ii++) {
             for (int jj = 0; jj < width; jj++) {
-                dy2_dx[ii * width + jj] = -2 * w[jj] * y[ii * width + jj] * dy[ii * width + jj];
+                dy2_dx[ii * width + jj] = grad_grad(xbar[ii * width + jj], y[ii * width + jj], functype) * w[jj] * w[jj];
             }
         }
     }
@@ -100,12 +148,12 @@ struct UnaggregatedDy2DxSFunctor {
 // calculate the gradient for all variables!
 template <typename FPTYPE>
 struct UnaggregatedDy2DxFunctor {
-    void operator()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dz_dx, const FPTYPE * dy_dx, const FPTYPE * dy2_dx, const int length, const int width, const int size, FPTYPE * dz2_dx) {
+    void operator()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dy_dx, const FPTYPE * dy2_dx, const FPTYPE * ybar, const int length, const int width, const int size, FPTYPE * dz2_dx, const int functype) {
 #pragma omp parallel for
         for (int kk = 0; kk < length; kk++) {
             for (int ii = 0; ii < width; ii++) {
                 //FPTYPE dz_drou = 1 - (z[kk * width + ii] - y[kk * size + ii % size]) * (z[kk * width + ii] - y[kk * size + ii % size]);
-                FPTYPE dz_drou = 1 - z[kk * width + ii] * z[kk * width + ii];
+                FPTYPE dz_drou = grad(ybar[kk * width + ii], z[kk * width + ii], functype);
                 FPTYPE accumulator = 0.0;
                 for (int jj = 0; jj < size; jj++) {
                     accumulator += w[jj * width + ii] * dy2_dx[kk * size + jj];
@@ -115,7 +163,7 @@ struct UnaggregatedDy2DxFunctor {
                 for (int jj = 0; jj < size; jj++) {
                     accumulator += w[jj * width + ii] * dy_dx[kk * size + jj];
                 }
-                dz_drou -= 2 * z[kk * width + ii] * (dz_dx[kk * width + ii] - dy_dx[kk * size + ii % size]) * accumulator;
+                dz_drou += grad_grad(ybar[kk * width + ii], z[kk * width + ii], functype) * accumulator * accumulator;
                 dz_drou += dy2_dx[kk * size + ii % size];
                 dz2_dx[kk * width + ii] = dz_drou;
             }
@@ -141,13 +189,18 @@ class UnaggregatedDyDxSOp : public OpKernel {
     void _Compute(OpKernelContext* context) {
         // Grab the input tensor
+        //xbar=xw+b
         int context_input_index = 0;
         const Tensor& y = context->input(context_input_index++);
         const Tensor& w = context->input(context_input_index++);
+        const Tensor& xbar = context->input(context_input_index++);
+        const Tensor& functype = context->input(context_input_index++);
context->input(context_input_index++); // set size of the sample - OP_REQUIRES (context, (y.shape().dims() == 2), errors::InvalidArgument ("Dim of table should be 1")); + OP_REQUIRES (context, (y.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2")); OP_REQUIRES (context, (w.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2")); + OP_REQUIRES(context, (xbar.shape().dims() == 2), errors::InvalidArgument("Dim of input should be 2")); + //check functype int context_output_index = 0; Tensor* dy_dx = NULL; @@ -159,9 +212,11 @@ class UnaggregatedDyDxSOp : public OpKernel { context->eigen_device(), // define actually graph execution device y.flat().data(), w.flat().data(), + xbar.flat().data(), y.shape().dim_size(0), y.shape().dim_size(1), - dy_dx->flat().data() + dy_dx->flat().data(), + functype.flat()(0) ); } private: @@ -182,14 +237,17 @@ class UnaggregatedDy2DxSOp : public OpKernel { const Tensor& y = context->input(context_input_index++); const Tensor& dy = context->input(context_input_index++); const Tensor& w = context->input(context_input_index++); + const Tensor& xbar = context->input(context_input_index++); + const Tensor& functype = context->input(context_input_index++); // set size of the sample OP_REQUIRES (context, (y.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2")); OP_REQUIRES (context, (dy.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2")); OP_REQUIRES (context, (w.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2")); + OP_REQUIRES (context, (xbar.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2")); int context_output_index = 0; - Tensor* dy2_dx = NULL; + Tensor* dy2_dx = NULL; OP_REQUIRES_OK(context, context->allocate_output(context_output_index++, y.shape(), &dy2_dx)); @@ -199,9 +257,11 @@ class UnaggregatedDy2DxSOp : public OpKernel { y.flat().data(), dy.flat().data(), w.flat().data(), + xbar.flat().data(), y.shape().dim_size(0), y.shape().dim_size(1), - dy2_dx->flat().data() + dy2_dx->flat().data(), + functype.flat()(0) ); } private: @@ -222,11 +282,14 @@ class UnaggregatedDyDxOp : public OpKernel { const Tensor& z = context->input(context_input_index++); const Tensor& w = context->input(context_input_index++); const Tensor& dy_dx = context->input(context_input_index++); + const Tensor& ybar = context->input(context_input_index++); + const Tensor& functype = context->input(context_input_index++); // set size of the sample - OP_REQUIRES (context, (z.shape().dims() == 2), errors::InvalidArgument ("Dim of table should be 1")); + OP_REQUIRES (context, (z.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2")); OP_REQUIRES (context, (w.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2")); OP_REQUIRES (context, (dy_dx.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2")); + OP_REQUIRES (context, (ybar.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2")); int context_output_index = 0; Tensor* dz_dx = NULL; @@ -239,10 +302,12 @@ class UnaggregatedDyDxOp : public OpKernel { z.flat().data(), w.flat().data(), dy_dx.flat().data(), + ybar.flat().data(), z.shape().dim_size(0), - z.shape().dim_size(1), - w.shape().dim_size(0), - dz_dx->flat().data() + z.shape().dim_size(1), //N1 + w.shape().dim_size(0), //N0 , N1=2N0 + dz_dx->flat().data(), + functype.flat()(0) ); } private: @@ -262,16 +327,17 @@ class UnaggregatedDy2DxOp : public OpKernel { int 
         const Tensor& z = context->input(context_input_index++);
         const Tensor& w = context->input(context_input_index++);
-        const Tensor& dz_dx = context->input(context_input_index++);
         const Tensor& dy_dx = context->input(context_input_index++);
         const Tensor& dy2_dx = context->input(context_input_index++);
+        const Tensor& ybar = context->input(context_input_index++);
+        const Tensor& functype = context->input(context_input_index++);
 
         // set size of the sample
         OP_REQUIRES (context, (z.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
         OP_REQUIRES (context, (w.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
-        OP_REQUIRES (context, (dz_dx.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
         OP_REQUIRES (context, (dy_dx.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
         OP_REQUIRES (context, (dy2_dx.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
+        OP_REQUIRES (context, (ybar.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
 
         int context_output_index = 0;
         Tensor* dz2_dx = NULL;
@@ -283,13 +349,14 @@ class UnaggregatedDy2DxOp : public OpKernel {
             context->eigen_device<Device>(),    // define actually graph execution device
             z.flat<FPTYPE>().data(),
             w.flat<FPTYPE>().data(),
-            dz_dx.flat<FPTYPE>().data(),
             dy_dx.flat<FPTYPE>().data(),
             dy2_dx.flat<FPTYPE>().data(),
+            ybar.flat<FPTYPE>().data(),
             z.shape().dim_size(0),
             z.shape().dim_size(1),
             w.shape().dim_size(0),
-            dz2_dx->flat<FPTYPE>().data()
+            dz2_dx->flat<FPTYPE>().data(),
+            functype.flat<int>()(0)
         );
     }
 private:
diff --git a/source/tests/test_tabulate.py b/source/tests/test_tabulate.py
new file mode 100644
index 0000000000..ce26c4e3e6
--- /dev/null
+++ b/source/tests/test_tabulate.py
@@ -0,0 +1,52 @@
+import unittest
+import numpy as np
+from deepmd.utils.tabulate import DPTabulate
+from deepmd.env import op_module
+from deepmd.env import tf
+from deepmd.common import gelu
+
+# For now, only test some of the OPs used by DPTabulate, implemented in source/op/unaggregated_grad.cc
+
+class TestDPTabulate(unittest.TestCase):
+    def test_op_tanh(self):
+        w = tf.constant([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1, 1.1, 1.2]], dtype='double')
+        x = tf.constant([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]], dtype='double')
+        b = tf.constant([[0.1], [0.2], [0.3], [0.4]], dtype='double')
+        xbar = tf.matmul(x, w) + b
+        y = tf.nn.tanh(xbar)
+        dy = op_module.unaggregated_dy_dx_s(y, w, xbar, tf.constant(1))
+        dy_array = tf.Session().run(dy)
+        answer = np.array([[8.008666403121351973e-02, 1.513925729426658651e-01, 2.134733287761668430e-01, 2.661983049806041501e-01],
+                           [4.010658815015744061e-02, 6.306476628799793926e-02, 7.332167904608145881e-02, 7.494218676568849269e-02],
+                           [1.561705624394135218e-02, 1.994112926507514427e-02, 1.887519955881525671e-02, 1.576442161040989692e-02],
+                           [5.492686739421748753e-03, 5.754985286040992763e-03, 4.493113544969218158e-03, 3.107638130764600777e-03]])
+
+        places = 18
+        for ii in range(dy_array.shape[0]):
+            for jj in range(dy_array.shape[1]):
+                self.assertAlmostEqual(dy_array[ii, jj], answer[ii, jj], places=places)
+
+    def test_op_gelu(self):
+        w = tf.constant([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [
+                        0.9, 1, 1.1, 1.2]], dtype='double')
+        x = tf.constant([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [
+                        0.7, 0.8, 0.9], [1.0, 1.1, 1.2]], dtype='double')
+        b = tf.constant([[0.1], [0.2], [0.3], [0.4]], dtype='double')
+        xbar = tf.matmul(x, w) + b
+        y = gelu(xbar)
+        dy = op_module.unaggregated_dy_dx_s(y, w, xbar, tf.constant(2))
+        dy_array = tf.Session().run(dy)
+        answer = np.array([[8.549286163555620821e-02, 1.782905778685600906e-01, 2.776474599997448833e-01, 3.827650237273348965e-01],
+                           [1.089906023807040714e-01, 2.230820937721638697e-01, 3.381867859682909927e-01, 4.513008399758057232e-01],
+                           [1.124254240556722684e-01, 2.209918074710395253e-01, 3.238894323148118759e-01, 4.220357318198978414e-01],
+                           [1.072173273655498138e-01, 2.082159073100979807e-01, 3.059816075270163083e-01, 4.032981557798429595e-01]])
+
+        places = 18
+        for ii in range(dy_array.shape[0]):
+            for jj in range(dy_array.shape[1]):
+                self.assertAlmostEqual(dy_array[ii, jj], answer[ii, jj], places=places)
+
+
+
+if __name__ == '__main__':
+    unittest.main()
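
Note on the Python-side change: DPTabulate now receives the embedding-net activation function and maps it to the integer functype that is fed to the ops (1 for tanh, 2 for gelu). The sketch below is a hypothetical stand-alone usage, not part of the patch: the path "frozen_model.pb" is a placeholder for a model frozen with activation_function = "gelu", and in practice DescrptSeA.enable_compression supplies self.compress_activation_fn automatically.

from deepmd.common import ACTIVATION_FN_DICT
from deepmd.utils.tabulate import DPTabulate

# "frozen_model.pb" is a placeholder for a frozen model trained with "gelu"
table = DPTabulate("frozen_model.pb",
                   type_one_side=False,
                   exclude_types=[],
                   activation_fn=ACTIVATION_FN_DICT["gelu"])
# table.functype is 2, so the tabulation ops evaluate the GELU derivatives;
# the default activation_fn (tf.nn.tanh) keeps the previous behavior.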
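
Note on the new grad / grad_grad helpers in source/op/unaggregated_grad.cc: for functype 2 they hard-code the first and second derivatives of the tanh-approximated GELU, gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), the form used by deepmd.common.gelu. The NumPy sketch below (not part of the patch) restates those formulas and checks them against central finite differences; it assumes SQRT_2_PI equals sqrt(2/pi), the value the .cc file takes from its headers.

import numpy as np

SQRT_2_PI = np.sqrt(2.0 / np.pi)   # assumed to match the SQRT_2_PI constant seen by unaggregated_grad.cc
GGELU = 0.044715                   # same constant as the #define in the patch

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(SQRT_2_PI * (x + GGELU * x**3)))

def grad(xbar, y, functype):
    # NumPy restatement of grad() from the patch; y = activation(xbar)
    if functype == 1:    # tanh
        return 1.0 - y * y
    if functype == 2:    # gelu
        var = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3))
        return 0.5 * SQRT_2_PI * xbar * (1.0 - var * var) * (3.0 * GGELU * xbar**2 + 1.0) + 0.5 * var + 0.5
    raise ValueError("unknown functype")

def grad_grad(xbar, y, functype):
    # NumPy restatement of grad_grad() from the patch
    if functype == 1:
        return -2.0 * y * (1.0 - y * y)
    if functype == 2:
        var1 = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3))
        var2 = SQRT_2_PI * (1.0 - var1 * var1) * (3.0 * GGELU * xbar**2 + 1.0)
        return (3.0 * GGELU * SQRT_2_PI * xbar**2 * (1.0 - var1 * var1)
                - SQRT_2_PI * xbar * var2 * (3.0 * GGELU * xbar**2 + 1.0) * var1 + var2)
    raise ValueError("unknown functype")

x = np.linspace(-3.0, 3.0, 601)
h = 1e-6
# d/dx gelu(x) vs. central finite difference: agreement to ~1e-9 (finite-difference noise)
fd1 = (gelu(x + h) - gelu(x - h)) / (2.0 * h)
print(np.max(np.abs(grad(x, gelu(x), 2) - fd1)))
# d2/dx2 gelu(x) vs. finite difference of the analytic first derivative
fd2 = (grad(x + h, gelu(x + h), 2) - grad(x - h, gelu(x - h), 2)) / (2.0 * h)
print(np.max(np.abs(grad_grad(x, gelu(x), 2) - fd2)))
# tanh branch: derivative of tanh vs. finite difference
fd_t = (np.tanh(x + h) - np.tanh(x - h)) / (2.0 * h)
print(np.max(np.abs(grad(x, np.tanh(x), 1) - fd_t)))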