-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Frontend][TensorFlow] Support BatchMatMul with input dimensions larger than 3 #3732
Conversation
2f6066c
to
89a9a7b
Compare
@srkreddy1238 and @alexeyr would you mind taking a look? |
outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)] | ||
num_outer_elts = 1 | ||
for outer_dim in outer_dims: | ||
num_outer_elts *= outer_dim |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this just call np.prod
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated
_test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'int32') | ||
_test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'float32', True, True) | ||
_test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 5, 6), 'int32', True, False) | ||
_test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 5, 6), 'float32', False, True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd prefer to have tests with different numbers of outer dimensions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated
@soiferj how hard is it to make dense/batchmatmul to support more dimension instead? |
From looking around the code, it seems like a good amount of work. All of computes would have to change, and we would probably want to change all of the schedules as well. We would also have to find and remove all asserts that verify the number of dimensions. Is this a preferred solution? I don't think the reshapes will have much overhead. |
I see. Can you make this function universal then? (move this into relay/op, and allow other ppl to use it as well). I need this function badly. |
Sure! I can work on that. |
@alexeyr what do you think of the proposed change? Should we make this functionality a little more generic? If so, where do you think it should go? Edit: never mind, I’m not going to make this change. Feel free to review as-is. |
@MarisaKirisame would you mind explaining what you want the interface to look like? Because I don’t know how all of the front ends look, I could make a method called ConvertTo3d. Then, converting back to the original dimensions would be handled by the front end (since it’s just one function call). |
@soiferj nevermind, I dont need it anymore, this is good to me. |
@soiferj I'd reply that I can't evaluate at the moment anyway :) But it looks good to me. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@tqchen @tmoreau89 would one of you be able to merge? |
Thanks @soiferj @alexeyr @MarisaKirisame ! |
…ns larger than 3 (apache#3732) * Support BatchMatMul with shapes greater than length 3 * Fixes * Add tests * Remove dependency on Python3 * Clean up * Merge with master * Resolve comments
…ns larger than 3 (apache#3732) * Support BatchMatMul with shapes greater than length 3 * Fixes * Add tests * Remove dependency on Python3 * Clean up * Merge with master * Resolve comments
…ns larger than 3 (apache#3732) * Support BatchMatMul with shapes greater than length 3 * Fixes * Add tests * Remove dependency on Python3 * Clean up * Merge with master * Resolve comments
This change adds support for BatchMatMul with n-dimensional inputs. For example, batch matrix multiplication of (2, 3, 4, 5) x (2, 3, 5, 4).