-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhance look_up_table op #8932
Enhance look_up_table op #8932
Conversation
0c7b04b
to
5f1b8eb
Compare
5f1b8eb
to
f1c3ecb
Compare
@@ -85,6 +91,44 @@ or not. And the output only shares the LoD information with input Ids. | |||
} | |||
}; | |||
|
|||
class ConcatRowsOpMaker : public framework::OpProtoAndCheckerMaker { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this name is not so good and maybe we just need a python layer but not a new operator
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your review @jacquesqiao
The parameters' meaning of ConcatRowsOp are different with LookUpTableOp, and the two operations' function is also different, so I define the new operation(ConcatRows). And I have added some annotation in the code.
Is your meaning that ConcatRowsOp is unnecessary?
9140f73
to
b9397b2
Compare
5de8215
to
1f757f5
Compare
… feature/add_concat_rows
c9072d6
to
94e43e9
Compare
94e43e9
to
92e2207
Compare
W_array[i] *= i | ||
W.set(W_array, place) | ||
|
||
Out = scope.var('Out').get_selected_rows() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remote these lines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
result_array = np.array(Out_tensor) | ||
|
||
for idx, row in enumerate(rows): | ||
assert (row == result_array[idx]).all() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add a comments about this usage of numpy.array
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
fix #8933