
Commit

fix compiling warning.
qingqing01 committed Oct 26, 2017
1 parent bcc0dad commit bd680f1
Showing 3 changed files with 23 additions and 34 deletions.
4 changes: 2 additions & 2 deletions paddle/operators/lstm_op.h
@@ -155,7 +155,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
  auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");

  auto* hidden_g = ctx.Input<LoDTensor>(framework::GradVarName("Hidden"));
- auto* cell_g = ctx.Input<LoDTensor>(framework::GradVarName("Cell"));
+ // auto* cell_g = ctx.Input<LoDTensor>(framework::GradVarName("Cell"));

  auto* in_g = ctx.Output<LoDTensor>(framework::GradVarName("Input"));
  auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
@@ -219,8 +219,8 @@ class LSTMGradKernel : public framework::OpKernel<T> {
  LoDTensor batch_cell_g;
  batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
  batch_cell_g.set_lod(batch_gate->lod());
- to_batch(device_ctx, *cell_g, batch_cell_g, false);
  // TODO(qingqing) support the case output cell has gradient.
+ // to_batch(device_ctx, *cell_g, batch_cell_g, false);
  zero(device_ctx, &batch_cell_g, static_cast<T>(0.0));

  LoDTensor batch_gate_g;
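Note on the lstm_op.h change: the gradient of the Cell output is not consumed for now. The declaration of cell_g and the to_batch copy that read it are commented out (the TODO marks the follow-up), and batch_cell_g is zero-filled instead. Commenting out the declaration together with its use keeps the build warning-clean; below is a minimal sketch of the unused-variable warning a leftover declaration would raise, assuming a GCC or Clang build with -Wall (file and function names are illustrative, not PaddlePaddle code):

    // warning_sketch.cc: illustrative only; compile with: g++ -Wall -c warning_sketch.cc
    void LstmGradSketch() {
      float cell_g = 0.0f;        // declared but never read below:
                                  // warning: unused variable 'cell_g' [-Wunused-variable]
      float batch_cell_g = 0.0f;  // the hunk zero-fills the batched cell gradient instead
      (void)batch_cell_g;         // explicit use, so only cell_g warns in this sketch
    }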
7 changes: 4 additions & 3 deletions paddle/operators/math/sequence2batch.h
@@ -58,7 +58,8 @@ class LoDTensor2BatchFunctor {
  if (!is_cal_batch_lod) {
  auto lods = batch.lod();
  PADDLE_ENFORCE_EQ(lods.size(), 2UL);
- PADDLE_ENFORCE_EQ(lods[1].size(), lod_tensor.dims()[0]);
+ PADDLE_ENFORCE_EQ(lods[1].size(),
+ static_cast<size_t>(lod_tensor.dims()[0]));
  CopyMatrixRowsFunctor<Place, T> to_batch;
  to_batch(context, lod_tensor, lods[1].data(), batch, true);
  return;
@@ -111,10 +112,10 @@ class LoDTensor2BatchFunctor {
  size_t* batch_starts = batch_lods[0].data();
  size_t* seq2batch_idx = batch_lods[1].data();
  batch_starts[0] = 0;
- for (size_t n = 0; n < num_batch; n++) {
+ for (int n = 0; n < num_batch; n++) {
  auto batch_id = static_cast<int>(batch_starts[n]);
  for (size_t i = 0; i < seq_info.size(); ++i) {
- size_t seq_len = seq_info[i].length;
+ int seq_len = seq_info[i].length;
  int start = seq_info[i].start;
  if (n < seq_len) {
  seq2batch_idx[batch_id] =
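Note on the sequence2batch.h change: lods[1].size() returns a size_t while lod_tensor.dims()[0] is a signed integer dimension, so comparing them directly triggers a signed/unsigned comparison warning; the added static_cast and the switch of the loop counter and seq_len to int keep both sides of each comparison the same signedness. A minimal reproduction of the warning, assuming GCC or Clang with -Wall (standalone sketch, not Paddle's PADDLE_ENFORCE_EQ macro):

    // sign_compare_sketch.cc: illustrative only; compile with: g++ -Wall -c sign_compare_sketch.cc
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    bool DimsMatch(const std::vector<std::size_t>& lod, std::int64_t rows) {
      // return lod.size() == rows;  // warning: comparison of integer expressions of
      //                             // different signedness [-Wsign-compare]
      return lod.size() == static_cast<std::size_t>(rows);  // cast one side, as the hunk above does
    }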
46 changes: 17 additions & 29 deletions python/paddle/v2/framework/tests/test_lstm_op.py
@@ -52,24 +52,23 @@ def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand):
  g = np.dot(h_pre, w_h) # 1 x 4D
  g = g + x
  g = np.reshape(g, (1, g.size))
- c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1)
+ c, g_i, g_f, g_o = np.split(g, 4, axis=1)
  if w_c is None:
  g_i = act_gate(g_i) # 1 x D
  g_f = act_gate(g_f) # 1 x D
  else:
  w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1)
  g_i = act_gate(g_i + w_ic * c_pre) # 1 x D
  g_f = act_gate(g_f + w_fc * c_pre) # 1 x D
- c = g_f * c_pre + g_i * act_cand(c_tmp) # 1 x D
+ c = g_f * c_pre + g_i * act_cand(c) # 1 x D

  if w_c is None:
  g_o = act_gate(g_o) # 1 x D
  else:
  _, _, w_oc = np.split(w_c, 3, axis=1)
  g_o = act_gate(g_o + w_oc * c) # 1 x D
  h = g_o * act_cell(c)
- bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1)
- return h, c, bg
+ return h, c

  def _reverse(x, lod):
  y = np.zeros_like(x)
@@ -82,7 +81,6 @@ def _reverse(x, lod):
  batch_size = len(offset) - 1
  hidden = []
  cell = []
- gate = []
  input = _reverse(input, offset) if is_reverse else input
  if w_b is not None:
  input = input + np.tile(w_b, (offset[-1], 1))
@@ -94,30 +92,26 @@ def _reverse(x, lod):
  c_pre = c0[i] # 1 x D
  for j in range(seq_len):
  # compute one step
- h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate,
- act_cell, act_cand)
+ h_pre, c_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate,
+ act_cell, act_cand)
  hidden.append(h_pre.flatten())
  cell.append(c_pre.flatten())
- gate.append(g_pre.flatten())

  hidden = np.array(hidden).astype('float64')
  cell = np.array(cell).astype('float64')
- gate = np.array(gate).astype('float64')

  hidden = _reverse(hidden, offset) if is_reverse else hidden
  cell = _reverse(cell, offset) if is_reverse else cell

- assert gate.shape == input.shape
  assert hidden.shape == (input.shape[0], input.shape[1] / 4)
  assert cell.shape == (input.shape[0], input.shape[1] / 4)
- return hidden, cell, gate
+ return hidden, cell


  class TestLstmOp(OpTest):
  def set_argument(self):
- self.lod = [[0, 2, 6, 9]]
+ self.lod = [[0, 2, 6]]
  self.D = 16
- self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]

  self.act_gate = 'sigmoid'
  self.act_cell = 'tanh'
@@ -141,22 +135,18 @@ def setUp(self):

  w_b = b[:, 0:4 * self.D]
  w_c = b[:, 4 * self.D:]
- h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
- ACTVATION[self.act_gate], ACTVATION[self.act_cell],
- ACTVATION[self.act_cand])
-
- g_sort = np.zeros_like(x)
- for i, j in enumerate(self.sort_idx):
- g_sort[i, :] = g[j, :]
+ h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
+ ACTVATION[self.act_gate], ACTVATION[self.act_cell],
+ ACTVATION[self.act_cand])

  self.inputs = {'Input': (x, self.lod), 'Weight': w, 'Bias': b}
- self.inputs['H0'] = h0
- self.inputs['C0'] = c0
+ if self.has_initial_state:
+ self.inputs['H0'] = h0
+ self.inputs['C0'] = c0

  self.outputs = {
  'Hidden': (h, self.lod),
  'Cell': (c, self.lod),
- 'BatchGate': g_sort,
  }
  self.attrs = {
  'usePeepholes': True,
@@ -179,9 +169,8 @@ def test_check_grad(self):

  class TestLstmOpHasNoInitial(TestLstmOp):
  def set_argument(self):
- self.lod = [[0, 2, 6, 9]]
- self.D = 64
- self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
+ self.lod = [[0, 2, 6]]
+ self.D = 16

  self.act_gate = 'sigmoid'
  self.act_cell = 'tanh'
@@ -193,9 +182,8 @@ def set_argument(self):

  class TestLstmOpRerverse(TestLstmOp):
  def set_argument(self):
- self.lod = [[0, 2, 6, 9]]
- self.D = 64
- self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
+ self.lod = [[0, 2, 6]]
+ self.D = 16

  self.act_gate = 'sigmoid'
  self.act_cell = 'tanh'
