Skip to content

Commit

Permalink
ensure ids in lookup table op must be a column vector (#4987)
Browse files Browse the repository at this point in the history
* ensure ids in lookup table op must be a column vector

* follow comments
  • Loading branch information
QiJune authored and reyoung committed Oct 23, 2017
1 parent 7d653c4 commit 40e7caf
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
7 changes: 6 additions & 1 deletion paddle/operators/lookup_table_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids");

PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);

ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
ctx->ShareLoD("Ids", /*->*/ "Out");
}
Expand All @@ -53,7 +56,9 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
" which is a learnable parameter.");
AddInput("Ids",
"An input with type int32 or int64"
"contains the ids to be looked up in W.");
"contains the ids to be looked up in W."
"Ids must be a column vector with rank = 2."
"The 2nd dimension size must be 1");
AddOutput("Out", "The lookup results, which have the same type with W.");
AddComment(R"DOC(
This operator is used to perform lookups on the parameter W,
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/v2/framework/tests/test_lookup_table_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ def setUp(self):
self.op_type = "lookup_table"
table = np.random.random((17, 31)).astype("float32")
ids = np.random.randint(0, 17, 4).astype("int32")
self.inputs = {'W': table, 'Ids': ids}
ids_expand = np.expand_dims(ids, axis=1)
self.inputs = {'W': table, 'Ids': ids_expand}
self.outputs = {'Out': table[ids]}

def test_check_output(self):
Expand Down

0 comments on commit 40e7caf

Please sign in to comment.