Decouple the activation function type from the model compression process in SE_A; both tanh and gelu are now available #1020

Merged 8 commits on Aug 27, 2021
Changes from 2 commits
4 changes: 3 additions & 1 deletion deepmd/descriptor/se_a.py
@@ -126,6 +126,7 @@ def __init__ (self,
self.uniform_seed = uniform_seed
self.seed_shift = embedding_net_rand_seed_shift(self.filter_neuron)
self.trainable = trainable
self.compress_activation_fn = get_activation_func(activation_function)
self.filter_activation_fn = get_activation_func(activation_function)
self.filter_precision = get_precision(precision)
self.filter_np_precision = get_np_precision(precision)
@@ -316,7 +317,8 @@ def enable_compression(self,
The overflow check frequency
"""
self.compress = True
self.table = DPTabulate(model_file, self.type_one_side, self.exclude_types)
self.table = DPTabulate(
model_file, self.type_one_side, self.exclude_types, self.compress_activation_fn)
self.table_config = [table_extrapolate, table_stride_1, table_stride_2, check_frequency]
self.lower, self.upper \
= self.table.build(min_nbor_dist,
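With this change the descriptor forwards its own activation function to the tabulation step instead of letting `DPTabulate` assume tanh. A minimal sketch of the resulting call pattern (the model file name is a placeholder, and the import paths are assumed from the deepmd-kit layout shown in this diff):

```python
from deepmd.common import get_activation_func
from deepmd.utils.tabulate import DPTabulate

# "tanh" or "gelu"; anything else is rejected by DPTabulate below
compress_activation_fn = get_activation_func("gelu")

table = DPTabulate(
    "frozen_model.pb",                    # placeholder path to a frozen model
    type_one_side=False,
    exclude_types=[],
    activation_fn=compress_activation_fn,
)
```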
45 changes: 33 additions & 12 deletions deepmd/utils/tabulate.py
@@ -34,7 +34,8 @@ class DPTabulate():
def __init__(self,
model_file : str,
type_one_side : bool = False,
exclude_types : List[List[int]] = []) -> None:
exclude_types : List[List[int]] = [],
activation_fn=tf.nn.tanh) -> None:
"""
Constructor
"""
@@ -44,6 +45,15 @@ def __init__(self,
self.exclude_types = exclude_types
if self.type_one_side and len(self.exclude_types) != 0:
raise RuntimeError('"type_one_side" is not compatible with "exclude_types"')

# functype
if activation_fn.__name__ == 'tf.nn.tanh' or activation_fn.__name__ == 'tanh':
self.functype = 1
elif activation_fn.__name__ == 'gelu':
self.functype = 2
else:
raise RunTimeError("Unknown actication function type!")
self.activation_fn = activation_fn

self.graph, self.graph_def = load_graph_def(self.model_file)
self.sess = tf.Session(graph = self.graph)
@@ -199,26 +209,37 @@ def _make_data(self, xx, idx):
xx = tf.reshape(xx, [xx.size, -1])
for layer in range(self.layer_size):
if layer == 0:
yy = self._layer_0(xx, self.matrix["layer_" + str(layer + 1)][idx], self.bias["layer_" + str(layer + 1)][idx])
dy = op_module.unaggregated_dy_dx_s(yy, self.matrix["layer_" + str(layer + 1)][idx])
dy2 = op_module.unaggregated_dy2_dx_s(yy, dy, self.matrix["layer_" + str(layer + 1)][idx])
xbar = tf.matmul(
xx, self.matrix["layer_" + str(layer + 1)][idx]) + self.bias["layer_" + str(layer + 1)][idx]
yy = self._layer_0(
xx, self.matrix["layer_" + str(layer + 1)][idx], self.bias["layer_" + str(layer + 1)][idx])
dy = op_module.unaggregated_dy_dx_s(
yy, self.matrix["layer_" + str(layer + 1)][idx], xbar, tf.constant(self.functype))
dy2 = op_module.unaggregated_dy2_dx_s(
yy, dy, self.matrix["layer_" + str(layer + 1)][idx], xbar, tf.constant(self.functype))
else:
tt, yy = self._layer_1(yy, self.matrix["layer_" + str(layer + 1)][idx], self.bias["layer_" + str(layer + 1)][idx])
dz = op_module.unaggregated_dy_dx(yy - tt, self.matrix["layer_" + str(layer + 1)][idx], dy)
dy2 = op_module.unaggregated_dy2_dx(yy - tt, self.matrix["layer_" + str(layer + 1)][idx], dz, dy, dy2)
ybar = tf.matmul(
yy, self.matrix["layer_" + str(layer + 1)][idx]) + self.bias["layer_" + str(layer + 1)][idx]
tt, zz = self._layer_1(
yy, self.matrix["layer_" + str(layer + 1)][idx], self.bias["layer_" + str(layer + 1)][idx])
dz = op_module.unaggregated_dy_dx(
zz - tt, self.matrix["layer_" + str(layer + 1)][idx], dy, ybar, tf.constant(self.functype))
dy2 = op_module.unaggregated_dy2_dx(
zz - tt, self.matrix["layer_" + str(layer + 1)][idx], dy, dy2, ybar, tf.constant(self.functype))
dy = dz

vv = yy.eval()
yy = zz

vv = zz.eval()
dd = dy.eval()
d2 = dy2.eval()
return vv, dd, d2

def _layer_0(self, x, w, b):
return tf.nn.tanh(tf.matmul(x, w) + b)
return self.activation_fn(tf.matmul(x, w) + b)

def _layer_1(self, x, w, b):
t = tf.concat([x, x], axis = 1)
return t, tf.nn.tanh(tf.matmul(x, w) + b) + t
t = tf.concat([x, x], axis=1)
return t, self.activation_fn(tf.matmul(x, w) + b) + t

def _save_data(self):
for ii in range(self.ntypes * self.ntypes):
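A side effect of the `_make_data` changes is that each layer now also exposes its pre-activation value (`xbar` for the first layer, `ybar` for the residual layers), which the custom ops need in order to evaluate the GELU derivatives. A NumPy sketch of what `_layer_0` and `_layer_1` compute with the now-configurable activation (illustrative only, not taken from the diff):

```python
import numpy as np

def layer_0(x, w, b, act):
    xbar = x @ w + b                     # pre-activation passed to unaggregated_dy_dx_s
    return act(xbar)

def layer_1(y, w, b, act):
    t = np.concatenate([y, y], axis=1)   # residual term; output width doubles
    ybar = y @ w + b                     # pre-activation passed to unaggregated_dy_dx
    return t, act(ybar) + t
```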
122 changes: 97 additions & 25 deletions source/op/unaggregated_grad.cc
@@ -2,42 +2,90 @@
#include "ComputeDescriptor.h"
#include "neighbor_list.h"


#define SQRT2_PI 0.7978845608028654
#define GGELU 0.044715

REGISTER_OP("UnaggregatedDyDxS")
.Attr("T: {float, double} = DT_DOUBLE")
.Input("y: T")
.Input("w: T")
.Input("w: T")
.Input("xbar: T")
.Input("functype: int32")
.Output("dy_dx: T");

REGISTER_OP("UnaggregatedDyDx")
.Attr("T: {float, double} = DT_DOUBLE")
.Input("z: T")
.Input("w: T")
.Input("dy_dx: T")
.Input("dy_dx: T")
.Input("ybar: T")
.Input("functype: int32")
.Output("dz_dx: T");

REGISTER_OP("UnaggregatedDy2DxS")
.Attr("T: {float, double} = DT_DOUBLE")
.Input("y: T")
.Input("dy: T")
.Input("w: T")
.Input("w: T")
.Input("xbar: T")
.Input("functype: int32")
.Output("dy2_dx: T");

REGISTER_OP("UnaggregatedDy2Dx")
.Attr("T: {float, double} = DT_DOUBLE")
.Input("z: T")
.Input("w: T")
.Input("dz_dx: T")
.Input("w: T")
.Input("dy_dx: T")
.Input("dy2_dx: T")
.Input("ybar: T")
.Input("functype: int32")
.Output("dz2_dx: T");
template <typename FPTYPE>
FPTYPE grad(const FPTYPE xbar, const FPTYPE y, const int functype) //functype=tanh, gelu, ..
{
switch (functype)
{
case 1:
return (1 - y * y);
case 2:
{
const FPTYPE var = tanh(SQRT2_PI * (xbar + GGELU * xbar * xbar * xbar));
return 0.5 * SQRT2_PI * xbar * (1 - var * var) * (3 * GGELU * xbar * xbar + 1) + 0.5 * var + 0.5;
}
default:
return -1;
}

}

template <typename FPTYPE>
FPTYPE grad_grad(const FPTYPE xbar, const FPTYPE y, const int functype)
{
switch (functype)
{
case 1:
return -2 * y * (1 - y * y);
case 2:
{
const FPTYPE var1 = tanh(SQRT2_PI * (xbar + GGELU * xbar * xbar * xbar));
const FPTYPE var2 = SQRT2_PI * (1 - var1 * var1) * (3 * GGELU * xbar * xbar + 1);
return 3 * GGELU * SQRT2_PI * xbar * xbar * (1 - var1 * var1) - SQRT2_PI * xbar * var2 * (3 * GGELU * xbar * xbar + 1) * var1 + var2;
}
default:
return -1;
}
}
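// Derivation sketch for the expressions implemented by grad() and grad_grad():
// functype 1 (tanh):   y = tanh(xbar)
//   dy/dxbar     = 1 - y^2
//   d^2y/dxbar^2 = -2 * y * (1 - y^2)
// functype 2 (GELU, tanh approximation):
//   y  = 0.5 * xbar * (1 + tanh(u)),  u = SQRT2_PI * (xbar + GGELU * xbar^3)
//   u' = SQRT2_PI * (1 + 3 * GGELU * xbar^2)
//   dy/dxbar     = 0.5 * (1 + tanh(u)) + 0.5 * xbar * (1 - tanh(u)^2) * u'
//   d^2y/dxbar^2 = (1 - tanh(u)^2) * u'
//                  - xbar * tanh(u) * (1 - tanh(u)^2) * u'^2
//                  + 3 * GGELU * SQRT2_PI * xbar^2 * (1 - tanh(u)^2)
// These match the case 1 and case 2 branches above term by term.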



template <typename FPTYPE>
struct UnaggregatedDyDxSFunctor {
void operator()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * w, const int length, const int width, FPTYPE * dy_dx) {
void operator()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * w, const FPTYPE* xbar, const int length, const int width, FPTYPE * dy_dx, const int functype) {
#pragma omp parallel for
for (int ii = 0; ii < length; ii++) {
for (int jj = 0; jj < width; jj++) {
dy_dx[ii * width + jj] = (1 - y[ii * width + jj] * y[ii * width + jj]) * w[jj];
dy_dx[ii * width + jj] = grad(xbar[ii * width + jj], y[ii * width + jj],functype)*w[jj];
}
}
}
@@ -53,12 +101,13 @@ struct UnaggregatedDyDxSFunctor {
// calculate the gradient for all variables!
template <typename FPTYPE>
struct UnaggregatedDyDxFunctor {
void operator()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dy_dx, const int length, const int width, const int size, FPTYPE * dz_dx) {
void operator()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dy_dx, const FPTYPE * ybar, const int length, const int width, const int size, FPTYPE * dz_dx, const int functype) {
//width=2*size
#pragma omp parallel for
for (int kk = 0; kk < length; kk++) {
for (int ii = 0; ii < width; ii++) {
//FPTYPE dz_drou = 1 - (z[kk * width + ii] - y[kk * size + ii % size]) * (z[kk * width + ii] - y[kk * size + ii % size]);
FPTYPE dz_drou = 1 - z[kk * width + ii] * z[kk * width + ii];
FPTYPE dz_drou = grad(ybar[kk*width+ii], z[kk * width + ii],functype);
FPTYPE accumulator = 0.0;
for (int jj = 0; jj < size; jj++) {
accumulator += w[jj * width + ii] * dy_dx[kk * size + jj];
@@ -80,11 +129,11 @@ struct UnaggregatedDyDxFunctor {

template <typename FPTYPE>
struct UnaggregatedDy2DxSFunctor {
void operator()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * dy, const FPTYPE * w, const int length, const int width, FPTYPE * dy2_dx) {
void operator()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * dy, const FPTYPE * w, const FPTYPE* xbar, const int length, const int width, FPTYPE * dy2_dx, const int functype) {
#pragma omp parallel for
for (int ii = 0; ii < length; ii++) {
for (int jj = 0; jj < width; jj++) {
dy2_dx[ii * width + jj] = -2 * w[jj] * y[ii * width + jj] * dy[ii * width + jj];
dy2_dx[ii * width + jj] = grad_grad(xbar[ii * width + jj],y[ii * width + jj],functype)*w[jj]*w[jj];
}
}
}
@@ -100,12 +149,12 @@ struct UnaggregatedDy2DxSFunctor {
// calculate the gradient for all variables!
template <typename FPTYPE>
struct UnaggregatedDy2DxFunctor {
void operator()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dz_dx, const FPTYPE * dy_dx, const FPTYPE * dy2_dx, const int length, const int width, const int size, FPTYPE * dz2_dx) {
void operator()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dy_dx, const FPTYPE * dy2_dx, const FPTYPE * ybar, const int length, const int width, const int size, FPTYPE * dz2_dx, const int functype) {
#pragma omp parallel for
for (int kk = 0; kk < length; kk++) {
for (int ii = 0; ii < width; ii++) {
//FPTYPE dz_drou = 1 - (z[kk * width + ii] - y[kk * size + ii % size]) * (z[kk * width + ii] - y[kk * size + ii % size]);
FPTYPE dz_drou = 1 - z[kk * width + ii] * z[kk * width + ii];
FPTYPE dz_drou = grad(ybar[kk*width+ii], z[kk * width + ii],functype);
FPTYPE accumulator = 0.0;
for (int jj = 0; jj < size; jj++) {
accumulator += w[jj * width + ii] * dy2_dx[kk * size + jj];
@@ -115,7 +164,7 @@ struct UnaggregatedDy2DxFunctor {
for (int jj = 0; jj < size; jj++) {
accumulator += w[jj * width + ii] * dy_dx[kk * size + jj];
}
dz_drou -= 2 * z[kk * width + ii] * (dz_dx[kk * width + ii] - dy_dx[kk * size + ii % size]) * accumulator;
dz_drou += grad_grad(ybar[kk * width + ii], z[kk * width + ii],functype) * accumulator * accumulator;
dz_drou += dy2_dx[kk * size + ii % size];
dz2_dx[kk * width + ii] = dz_drou;
}
@@ -141,13 +190,18 @@ class UnaggregatedDyDxSOp : public OpKernel {

void _Compute(OpKernelContext* context) {
// Grab the input tensor
//xbar=xw+b
int context_input_index = 0;
const Tensor& y = context->input(context_input_index++);
const Tensor& w = context->input(context_input_index++);
const Tensor& xbar = context->input(context_input_index++);
const Tensor& functype = context->input(context_input_index++);

// set size of the sample
OP_REQUIRES (context, (y.shape().dims() == 2), errors::InvalidArgument ("Dim of table should be 1"));
OP_REQUIRES (context, (y.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES (context, (w.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES(context, (xbar.shape().dims() == 2), errors::InvalidArgument("Dim of input should be 2"));
//check functype

int context_output_index = 0;
Tensor* dy_dx = NULL;
@@ -159,11 +213,14 @@ class UnaggregatedDyDxSOp : public OpKernel {
context->eigen_device<Device>(), // define actually graph execution device
y.flat<FPTYPE>().data(),
w.flat<FPTYPE>().data(),
xbar.flat<FPTYPE>().data(),
y.shape().dim_size(0),
y.shape().dim_size(1),
dy_dx->flat<FPTYPE>().data()
dy_dx->flat<FPTYPE>().data(),
functype.flat<int32>()(0)
);
}

private:
};

@@ -182,11 +239,14 @@ class UnaggregatedDy2DxSOp : public OpKernel {
const Tensor& y = context->input(context_input_index++);
const Tensor& dy = context->input(context_input_index++);
const Tensor& w = context->input(context_input_index++);
const Tensor& xbar = context->input(context_input_index++);
const Tensor& functype = context->input(context_input_index++);

// set size of the sample
OP_REQUIRES (context, (y.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES (context, (dy.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES (context, (w.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES (context, (xbar.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));

int context_output_index = 0;
Tensor* dy2_dx = NULL;
@@ -199,11 +259,14 @@ class UnaggregatedDy2DxSOp : public OpKernel {
y.flat<FPTYPE>().data(),
dy.flat<FPTYPE>().data(),
w.flat<FPTYPE>().data(),
xbar.flat<FPTYPE>().data(),
y.shape().dim_size(0),
y.shape().dim_size(1),
dy2_dx->flat<FPTYPE>().data()
dy2_dx->flat<FPTYPE>().data(),
functype.flat<int32>()(0)
);
}

private:
};

@@ -222,11 +285,14 @@ class UnaggregatedDyDxOp : public OpKernel {
const Tensor& z = context->input(context_input_index++);
const Tensor& w = context->input(context_input_index++);
const Tensor& dy_dx = context->input(context_input_index++);
const Tensor& ybar = context->input(context_input_index++);
const Tensor& functype = context->input(context_input_index++);

// set size of the sample
OP_REQUIRES (context, (z.shape().dims() == 2), errors::InvalidArgument ("Dim of table should be 1"));
OP_REQUIRES (context, (z.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES (context, (w.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES (context, (dy_dx.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES (context, (ybar.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));

int context_output_index = 0;
Tensor* dz_dx = NULL;
@@ -239,12 +305,15 @@ class UnaggregatedDyDxOp : public OpKernel {
z.flat<FPTYPE>().data(),
w.flat<FPTYPE>().data(),
dy_dx.flat<FPTYPE>().data(),
ybar.flat<FPTYPE>().data(),
z.shape().dim_size(0),
z.shape().dim_size(1),
w.shape().dim_size(0),
dz_dx->flat<FPTYPE>().data()
z.shape().dim_size(1), //N1
w.shape().dim_size(0), //N0 , N1=2N0
dz_dx->flat<FPTYPE>().data(),
functype.flat<int32>()(0)
);
}

private:
};

@@ -262,16 +331,17 @@ class UnaggregatedDy2DxOp : public OpKernel {
int context_input_index = 0;
const Tensor& z = context->input(context_input_index++);
const Tensor& w = context->input(context_input_index++);
const Tensor& dz_dx = context->input(context_input_index++);
const Tensor& dy_dx = context->input(context_input_index++);
const Tensor& dy2_dx = context->input(context_input_index++);
const Tensor& ybar = context->input(context_input_index++);
const Tensor& functype = context->input(context_input_index++);

// set size of the sample
OP_REQUIRES (context, (z.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES (context, (w.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES (context, (dz_dx.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES (context, (dy_dx.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES (context, (dy2_dx.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
OP_REQUIRES (context, (ybar.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));

int context_output_index = 0;
Tensor* dz2_dx = NULL;
@@ -283,15 +353,17 @@ class UnaggregatedDy2DxOp : public OpKernel {
context->eigen_device<Device>(), // define actually graph execution device
z.flat<FPTYPE>().data(),
w.flat<FPTYPE>().data(),
dz_dx.flat<FPTYPE>().data(),
dy_dx.flat<FPTYPE>().data(),
dy2_dx.flat<FPTYPE>().data(),
ybar.flat<FPTYPE>().data(),
z.shape().dim_size(0),
z.shape().dim_size(1),
w.shape().dim_size(0),
dz2_dx->flat<FPTYPE>().data()
dz2_dx->flat<FPTYPE>().data(),
functype.flat<int32>()(0)
);
}

private:
};

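As a sanity check on the analytic expressions in `grad()`, a small NumPy sketch (illustrative, not part of the PR) comparing both supported functypes against central finite differences of the corresponding activations:

```python
import numpy as np

SQRT2_PI = 0.7978845608028654
GGELU = 0.044715

def act(x, functype):
    if functype == 1:                    # tanh
        return np.tanh(x)
    return 0.5 * x * (1.0 + np.tanh(SQRT2_PI * (x + GGELU * x**3)))  # GELU (tanh approx.)

def grad(xbar, y, functype):             # mirrors the C++ grad() above
    if functype == 1:
        return 1.0 - y * y
    var = np.tanh(SQRT2_PI * (xbar + GGELU * xbar**3))
    return (0.5 * SQRT2_PI * xbar * (1.0 - var * var) * (3.0 * GGELU * xbar * xbar + 1.0)
            + 0.5 * var + 0.5)

x = np.linspace(-3.0, 3.0, 13)
h = 1e-6
for ft in (1, 2):
    numeric = (act(x + h, ft) - act(x - h, ft)) / (2.0 * h)   # central difference
    analytic = grad(x, act(x, ft), ft)
    assert np.allclose(numeric, analytic, atol=1e-5)
```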