diff --git a/caffe2/operators/recurrent_network_op.h b/caffe2/operators/recurrent_network_op.h
index 372a69d80b555..07d8f818d3744 100644
--- a/caffe2/operators/recurrent_network_op.h
+++ b/caffe2/operators/recurrent_network_op.h
@@ -81,16 +81,32 @@ void initializeRecurrentInput(
 
   auto inputBlob = ws->GetBlob(rc.input);
   CAFFE_ENFORCE(inputBlob);
   const auto& input = inputBlob->template Get<Tensor<Context>>();
-  CAFFE_ENFORCE_EQ(input.ndim(), 3, rc.input);
-  CAFFE_ENFORCE_EQ(input.dim(0), 1, rc.input);
-  CAFFE_ENFORCE_EQ(input.dim(1), batchSize, rc.input);
+  CAFFE_ENFORCE_GE(input.ndim(), 1, rc.input);
+  CAFFE_ENFORCE_LE(input.ndim(), 3, rc.input);
+  const auto stateSize = input.dim(input.ndim() - 1);
   // States at [0, ..., T] (inclusive)
-  state->Resize(seqLen + 1, batchSize, input.dim(2));
-  context->template Copy<T, Context, Context>(
-      batchSize * input.dim(2),
-      input.template data<T>(),
-      state->template mutable_data<T>());
+  state->Resize(seqLen + 1, batchSize, stateSize);
+
+  if (input.ndim() == 3) {
+    CAFFE_ENFORCE_EQ(input.dim(0), 1, rc.input);
+  }
+  if (input.ndim() >= 2) {
+    CAFFE_ENFORCE_EQ(input.dim(input.ndim() - 2), batchSize, rc.input);
+    context->template Copy<T, Context, Context>(
+        batchSize * stateSize,
+        input.template data<T>(),
+        state->template mutable_data<T>());
+  } else {
+    for (int i = 0; i < batchSize; ++i) {
+      // Usually, the initial state is the same for all inputs in the batch.
+      // So the op conveniently accepts 1-D input and copies it batchSize times.
+      context->template Copy<T, Context, Context>(
+          stateSize,
+          input.template data<T>(),
+          state->template mutable_data<T>() + i * stateSize);
+    }
+  }
 }
 
 template <typename T, typename Context>
@@ -487,14 +503,30 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
       VLOG(1) << "Resetting output " << def().output(outputIdx)
               << " like input " << def().input(inputId);
       Output(outputIdx)->ResizeLike(Input(inputId));
-
       auto pBlob = sharedWs_->GetBlob(recurrentGradients_[i].grad);
       CAFFE_ENFORCE(pBlob);
       auto* p = pBlob->template GetMutable<Tensor<Context>>();
-      context_.template Copy<T, Context, Context>(
-          Output(outputIdx)->size(),
-          p->template data<T>(),
-          Output(outputIdx)->template mutable_data<T>());
+
+      if (Input(inputId).ndim() >= 2) {
+        context_.template Copy<T, Context, Context>(
+            Output(outputIdx)->size(),
+            p->template data<T>(),
+            Output(outputIdx)->template mutable_data<T>());
+      } else {
+        const auto recurrentStateSize = Input(inputId).dim32(0);
+        context_.template Copy<T, Context, Context>(
+            recurrentStateSize,
+            p->template data<T>(),
+            Output(outputIdx)->template mutable_data<T>());
+        for (int j = 1; j < batchSize; ++j) {
+          math::Add<T, Context>(
+              recurrentStateSize,
+              p->template data<T>() + j * recurrentStateSize,
+              Output(outputIdx)->template data<T>(),
+              Output(outputIdx)->template mutable_data<T>(),
+              &context_);
+        }
+      }
     }
 
     return true;
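
Note on the semantics introduced above (a minimal NumPy sketch, not part of the patch): the initial recurrent state may now be 3-D (1, N, D), 2-D (N, D), or 1-D (D,). A 1-D state is copied across the batch on the forward pass, so on the backward pass its gradient is the sum of the per-batch-item gradients. The helper names init_state and init_state_grad below are hypothetical:

    import numpy as np

    def init_state(initial, batch_size):
        # Mirrors initializeRecurrentInput: accepts (1, N, D), (N, D) or (D,)
        # and produces the (N, D) block written at state[0].
        state_size = initial.shape[-1]
        if initial.ndim == 1:
            # 1-D input: copy the same state for every batch item.
            return np.tile(initial, (batch_size, 1))
        return initial.reshape(batch_size, state_size)

    def init_state_grad(state_grad):
        # Mirrors the math::Add loop in RecurrentNetworkGradientOp: the
        # gradient of a broadcast 1-D state sums over the batch axis.
        return state_grad.sum(axis=0)
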
diff --git a/caffe2/python/hypothesis_test_util.py b/caffe2/python/hypothesis_test_util.py
index 79f06d95973da..cda30b49f6d93 100644
--- a/caffe2/python/hypothesis_test_util.py
+++ b/caffe2/python/hypothesis_test_util.py
@@ -456,7 +456,12 @@ def softsign(X):
             if atol is None:
                 atol = threshold
             np.testing.assert_allclose(
-                output, ref, atol=atol, rtol=threshold)
+                output, ref, atol=atol, rtol=threshold,
+                err_msg=(
+                    'Output {0} is not matching the reference'.format(
+                        output_blob_name,
+                    )),
+            )
             outs.append(output)
             if grad_reference and output_to_grad:
                 self._assertGradReferenceChecks(
diff --git a/caffe2/python/operator_test/recurrent_network_test.py b/caffe2/python/operator_test/recurrent_network_test.py
index 9b8bcff5bd1ec..a247bf8376ca8 100644
--- a/caffe2/python/operator_test/recurrent_network_test.py
+++ b/caffe2/python/operator_test/recurrent_network_test.py
@@ -58,7 +58,7 @@ def lstm_reference(input, hidden_input, cell_input,
     T = input.shape[0]
     N = input.shape[1]
     G = input.shape[2]
-    D = hidden_input.shape[2]
+    D = hidden_input.shape[hidden_input.ndim - 1]
     hidden = np.zeros(shape=(T + 1, N, D))
     cell = np.zeros(shape=(T + 1, N, D))
     assert hidden.shape[0] == T + 1
@@ -145,12 +145,19 @@ def lstm(self, model, create_lstm, t, n, d, ref, gradients_to_check,
         workspace.RunNetOnce(model.param_init_net)
         input_blob = op.input[0]
+        def generate_random_state(n, d):
+            ndim = int(np.random.choice(3, 1)) + 1
+            if ndim == 1:
+                return np.random.randn(1, n, d).astype(np.float32)
+            random_state = np.random.randn(n, d).astype(np.float32)
+            if ndim == 3:
+                random_state = random_state.reshape([1, n, d])
+            return random_state
+
         workspace.FeedBlob(
             str(input_blob),
             np.random.randn(t, n, d * 4).astype(np.float32))
-        workspace.FeedBlob(
-            "hidden_init", np.random.randn(1, n, d).astype(np.float32))
-        workspace.FeedBlob(
-            "cell_init", np.random.randn(1, n, d).astype(np.float32))
+        workspace.FeedBlob("hidden_init", generate_random_state(n, d))
+        workspace.FeedBlob("cell_init", generate_random_state(n, d))
         workspace.FeedBlob(
             "seq_lengths", np.random.randint(0, t, size=(n,)).astype(np.int32))
         inputs = [workspace.FetchBlob(name) for name in op.input]
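
For reference, lstm_reference only needs the one-line change to D above because NumPy assignment broadcasting handles the rest: writing an initial state into the (N, D) slice at t = 0 (presumably via an assignment like hidden[0, :, :] = hidden_input in the unshown part of the function) accepts all three accepted shapes. A quick illustration with made-up sizes:

    import numpy as np

    N, D = 4, 8
    hidden = np.zeros((3, N, D))
    # (1, N, D), (N, D) and (D,) all broadcast into the (N, D) slice.
    for init in (np.ones((1, N, D)), np.ones((N, D)), np.ones(D)):
        hidden[0, :, :] = init
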