Allow non-batched initial recurrent states for RecurrentNetworkOp
Summary: title

Reviewed By: salexspb

Differential Revision: D4493728

fbshipit-source-id: a9ba25bd325b413ed15c35754afb9ed562b1a60c
urikz authored and facebook-github-bot committed Feb 6, 2017
1 parent 947e5fe commit 280718b
Showing 3 changed files with 63 additions and 19 deletions.
58 changes: 45 additions & 13 deletions caffe2/operators/recurrent_network_op.h
@@ -81,16 +81,32 @@ void initializeRecurrentInput(
   auto inputBlob = ws->GetBlob(rc.input);
   CAFFE_ENFORCE(inputBlob);
   const auto& input = inputBlob->template Get<Tensor<Context>>();
-  CAFFE_ENFORCE_EQ(input.ndim(), 3, rc.input);
-  CAFFE_ENFORCE_EQ(input.dim(0), 1, rc.input);
-  CAFFE_ENFORCE_EQ(input.dim(1), batchSize, rc.input);
+  CAFFE_ENFORCE_GE(input.ndim(), 1, rc.input);
+  CAFFE_ENFORCE_LE(input.ndim(), 3, rc.input);
+
+  const auto stateSize = input.dim(input.ndim() - 1);
   // States at [0, ..., T] (inclusive)
-  state->Resize(seqLen + 1, batchSize, input.dim(2));
-  context->template Copy<T, Context, Context>(
-      batchSize * input.dim(2),
-      input.template data<T>(),
-      state->template mutable_data<T>());
+  state->Resize(seqLen + 1, batchSize, stateSize);
+
+  if (input.ndim() == 3) {
+    CAFFE_ENFORCE_EQ(input.dim(0), 1, rc.input);
+  }
+  if (input.ndim() >= 2) {
+    CAFFE_ENFORCE_EQ(input.dim(input.ndim() - 2), batchSize, rc.input);
+    context->template Copy<T, Context, Context>(
+        batchSize * stateSize,
+        input.template data<T>(),
+        state->template mutable_data<T>());
+  } else {
+    for (int i = 0; i < batchSize; ++i) {
+      // Usually, the initial state is the same for all inputs in the batch.
+      // So the op conveniently accepts 1-D input and copies it batchSize times.
+      context->template Copy<T, Context, Context>(
+          stateSize,
+          input.template data<T>(),
+          state->template mutable_data<T>() + i * stateSize);
+    }
+  }
 }
 
 template <typename T, typename Context>
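
As a reading aid, here is a minimal numpy sketch (illustrative shapes and names, not part of the commit) of what the new 1-D branch above does: a single initial state of length stateSize is copied batchSize times into time step 0 of the state tensor.

import numpy as np

seq_len, batch_size, state_size = 3, 4, 8                  # hypothetical sizes
initial = np.random.randn(state_size).astype(np.float32)   # 1-D initial state

state = np.zeros((seq_len + 1, batch_size, state_size), dtype=np.float32)
# Equivalent of the per-row Copy calls in the else branch: broadcast the
# single state across the batch at t = 0.
for i in range(batch_size):
    state[0, i, :] = initial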
@@ -487,14 +503,30 @@ class RecurrentNetworkGradientOp final : public Operator<Context> {
       VLOG(1) << "Resetting output " << def().output(outputIdx)
               << " like input " << def().input(inputId);
       Output(outputIdx)->ResizeLike(Input(inputId));
 
       auto pBlob = sharedWs_->GetBlob(recurrentGradients_[i].grad);
       CAFFE_ENFORCE(pBlob);
       auto* p = pBlob->template GetMutable<Tensor<Context>>();
-      context_.template Copy<T, Context, Context>(
-          Output(outputIdx)->size(),
-          p->template data<T>(),
-          Output(outputIdx)->template mutable_data<T>());
 
+      if (Input(inputId).ndim() >= 2) {
+        context_.template Copy<T, Context, Context>(
+            Output(outputIdx)->size(),
+            p->template data<T>(),
+            Output(outputIdx)->template mutable_data<T>());
+      } else {
+        const auto recurrentStateSize = Input(inputId).dim32(0);
+        context_.template Copy<T, Context, Context>(
+            recurrentStateSize,
+            p->template data<T>(),
+            Output(outputIdx)->template mutable_data<T>());
+        for (int j = 1; j < batchSize; ++j) {
+          math::Add<T, Context>(
+              recurrentStateSize,
+              p->template data<T>() + j * recurrentStateSize,
+              Output(outputIdx)->template data<T>(),
+              Output(outputIdx)->template mutable_data<T>(),
+              &context_);
+        }
+      }
     }
 
     return true;
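
The backward pass mirrors this: when the initial state was 1-D, the gradient blob still holds one slice per batch element, so the slices are reduced into a single vector (one Copy, then repeated Add). A hedged numpy sketch of that accumulation, with illustrative names:

import numpy as np

batch_size, state_size = 4, 8                                    # hypothetical sizes
p = np.random.randn(batch_size, state_size).astype(np.float32)  # per-row state gradient

grad = p[0].copy()              # Copy of row 0, as in the else branch
for j in range(1, batch_size):  # math::Add accumulates the remaining rows
    grad += p[j]

assert np.allclose(grad, p.sum(axis=0), atol=1e-5)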
7 changes: 6 additions & 1 deletion caffe2/python/hypothesis_test_util.py
@@ -456,7 +456,12 @@ def softsign(X):
             if atol is None:
                 atol = threshold
             np.testing.assert_allclose(
-                output, ref, atol=atol, rtol=threshold)
+                output, ref, atol=atol, rtol=threshold,
+                err_msg=(
+                    'Output {0} is not matching the reference'.format(
+                        output_blob_name,
+                    )),
+            )
             outs.append(output)
         if grad_reference and output_to_grad:
             self._assertGradReferenceChecks(
17 changes: 12 additions & 5 deletions caffe2/python/operator_test/recurrent_network_test.py
@@ -58,7 +58,7 @@ def lstm_reference(input, hidden_input, cell_input,
     T = input.shape[0]
     N = input.shape[1]
     G = input.shape[2]
-    D = hidden_input.shape[2]
+    D = hidden_input.shape[hidden_input.ndim - 1]
     hidden = np.zeros(shape=(T + 1, N, D))
     cell = np.zeros(shape=(T + 1, N, D))
     assert hidden.shape[0] == T + 1
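
A small illustration (assumed shapes, not from the commit) of why the reference now indexes the last axis: the hidden size D can be read off regardless of the rank of the initial state.

import numpy as np

for shape in [(8,), (4, 8), (1, 4, 8)]:   # (d,), (n, d), (1, n, d)
    hidden_input = np.zeros(shape)
    D = hidden_input.shape[hidden_input.ndim - 1]
    assert D == 8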
@@ -145,12 +145,19 @@ def lstm(self, model, create_lstm, t, n, d, ref, gradients_to_check,
         workspace.RunNetOnce(model.param_init_net)
         input_blob = op.input[0]
 
+        def generate_random_state(n, d):
+            ndim = int(np.random.choice(3, 1)) + 1
+            if ndim == 1:
+                return np.random.randn(1, n, d).astype(np.float32)
+            random_state = np.random.randn(n, d).astype(np.float32)
+            if ndim == 3:
+                random_state = random_state.reshape([1, n, d])
+            return random_state
+
         workspace.FeedBlob(
             str(input_blob), np.random.randn(t, n, d * 4).astype(np.float32))
-        workspace.FeedBlob(
-            "hidden_init", np.random.randn(1, n, d).astype(np.float32))
-        workspace.FeedBlob(
-            "cell_init", np.random.randn(1, n, d).astype(np.float32))
+        workspace.FeedBlob("hidden_init", generate_random_state(n, d))
+        workspace.FeedBlob("cell_init", generate_random_state(n, d))
         workspace.FeedBlob(
             "seq_lengths", np.random.randint(0, t, size=(n,)).astype(np.int32))
         inputs = [workspace.FetchBlob(name) for name in op.input]
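
For context, a hedged sketch (illustrative sizes, not part of the test) of how a caller can now feed the LSTM's initial hidden state either batched or broadcast; workspace.FeedBlob is the same call the test uses.

import numpy as np
from caffe2.python import workspace

n, d = 4, 8   # hypothetical batch size and hidden dimension

# Batched form: one initial state per sequence in the batch.
workspace.FeedBlob("hidden_init", np.random.randn(1, n, d).astype(np.float32))

# Non-batched form accepted after this change: a single d-dimensional state
# that initializeRecurrentInput copies batchSize times.
workspace.FeedBlob("hidden_init", np.random.randn(d).astype(np.float32))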
