[JAX]: Support jax.lax.select_n operation for JAX #28025
base: master
Signed-off-by: 11happy <[email protected]>
build_jenkins
humble ping! @rkazants
Hello @rkazants, please let me know if there's anything needed from my end.
```cpp
// Input 0 is `which`; inputs 1..num_inputs-1 are the cases.
OutputVector cases_vector(num_inputs - 1);
for (int ind = 1; ind < num_inputs; ++ind) {
    cases_vector[ind - 1] = context.get_input(ind);
}

// Concatenate the (1-D) cases along axis 0, then view them as a
// [num_cases, N] matrix, where N is read from the static shape of `which`.
Output<Node> cases = std::make_shared<v0::Concat>(cases_vector, 0);
auto which_shape = which.get_shape();
std::vector<int64_t> cases_reshape_shape = {static_cast<int64_t>(num_inputs - 1), static_cast<int64_t>(which_shape[0])};
std::vector<int64_t> which_reshape_shape = {1, static_cast<int64_t>(which_shape[0])};

cases = std::make_shared<v1::Reshape>(cases, ov::op::v0::Constant::create(element::i64, Shape{2}, cases_reshape_shape), false);
which = std::make_shared<v1::Reshape>(which, ov::op::v0::Constant::create(element::i64, Shape{2}, which_reshape_shape), false);
// For each column j, pick row which[0][j] of the stacked cases.
Output<Node> result = std::make_shared<v6::GatherElements>(cases, which, 0);
return {result};
```
The current solution unfortunately has a limitation: it uses `get_shape`, which only works for static shapes in the graph, so the decomposition is not reshapeable.
Let us implement it as follows:
- unsqueeze all cases at axis=0 and concatenate them along axis=0
- compute the shape of one of the (original) cases and broadcast `which` to this shape
- apply GatherElements to the concatenated cases and the broadcast `which`, with axis=0
done
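The decomposition proposed in the review can be sketched in NumPy to check its semantics before wiring it up in the frontend. This is only an illustration under stated assumptions, not OpenVINO code: `np.expand_dims`, `np.concatenate`, `np.broadcast_to` and `np.take_along_axis` stand in for Unsqueeze, Concat, Broadcast and GatherElements, and `select_n_decomposed` is a hypothetical name. One extra step not spelled out in the review is added: GatherElements expects indices of the same rank as the data, so the broadcast `which` also gets a leading axis of size 1.

```python
import numpy as np


def select_n_decomposed(which, *cases):
    """Hypothetical NumPy mirror of the reviewed decomposition of select_n."""
    # 1. Unsqueeze every case at axis 0 and concatenate along axis 0:
    #    shape becomes [num_cases, *case_shape].
    stacked = np.concatenate([np.expand_dims(c, 0) for c in cases], axis=0)
    # 2. Take the shape of one original case and broadcast `which` to it,
    #    so a scalar `which` also works. Bool maps to 0/1 indices.
    case_shape = np.shape(cases[0])
    idx = np.broadcast_to(np.asarray(which), case_shape).astype(np.int64)
    # 3. GatherElements-style gather along axis 0; the indices need the same
    #    rank as `stacked`, hence the extra leading axis of size 1.
    gathered = np.take_along_axis(stacked, np.expand_dims(idx, 0), axis=0)
    # Drop the leading axis to get back the shape of a single case.
    return gathered[0]
```

Because the case shape is taken from a tensor at run time rather than baked into constants, the same structure should stay valid when the model is reshaped, which is the point of the review comment.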
```python
@pytest.mark.parametrize("input_shape", [1,2,3,4,5,6,7,8,9,10])
@pytest.mark.parametrize("input_type", [np.int32, np.int64, np.bool, bool])
```
NumPy deprecated and then removed the `np.bool` alias, so we need to use only `bool`.
Suggested change:
```diff
-@pytest.mark.parametrize("input_type", [np.int32, np.int64, np.bool, bool])
+@pytest.mark.parametrize("input_type", [np.int32, np.int64, bool])
```
done
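A quick sanity check of the `np.bool` point, assuming NumPy is available in the test environment: the built-in `bool` is accepted wherever NumPy expects a dtype and resolves to NumPy's boolean dtype, so dropping `np.bool` from the parametrization loses no coverage. `make_input` is a hypothetical helper for illustration, not the test's actual input generator.

```python
import numpy as np

# Same values as the suggested `input_type` parametrization.
input_types = [np.int32, np.int64, bool]

# The built-in `bool` resolves to NumPy's boolean dtype (np.bool_).
dtypes = [np.dtype(t) for t in input_types]


def make_input(shape, input_type, seed=0):
    # Hypothetical helper: random 0/1 data cast to the requested type,
    # which works for the integer types and for bool alike.
    rng = np.random.default_rng(seed)
    return rng.integers(0, 2, size=shape).astype(input_type)
```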
```python
return jax_select_n, None, 'select_n'
```

```python
@pytest.mark.parametrize("input_shape", [1,2,3,4,5,6,7,8,9,10])
```
Let us use separate shapes for `which` and `cases` to check broadcasting.
While trying to do this, I got this error:
`TypeError: select `which` must be scalar or have the same shape as cases`
`jax.lax.select_n` does not allow `which` and `cases` to have different shapes, except for a scalar `which`.
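One possible way to still exercise broadcasting within that restriction is to parametrize the test over a scalar `which` as well as a full-shape one. The sketch below is a hypothetical helper (not part of JAX or the test suite) that mirrors the rule quoted in the error message above, so shape pairs can be validated before they are handed to the test; `VALID_SHAPE_PAIRS` is likewise an illustrative name.

```python
import numpy as np


def check_select_n_shapes(which, *cases):
    """Hypothetical check mirroring the shape rule reported for
    jax.lax.select_n: all cases share one shape, and `which` is either a
    scalar or has exactly that shape (no other broadcasting)."""
    if not cases:
        raise ValueError("select_n needs at least one case")
    case_shape = np.shape(cases[0])
    if any(np.shape(c) != case_shape for c in cases[1:]):
        raise TypeError("all cases must have the same shape")
    which_shape = np.shape(which)
    if which_shape not in ((), case_shape):
        raise TypeError(
            "select `which` must be scalar or have the same shape as cases, "
            f"got which shape {which_shape} and case shape {case_shape}")


# (which_shape, case_shape) pairs a test could parametrize over: the only
# shape mismatch that is allowed is a scalar `which`.
VALID_SHAPE_PAIRS = [((), (4,)), ((4,), (4,)), ((), (2, 3)), ((2, 3), (2, 3))]
```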
Overview:
This pull request fixes #26570.
Testing:
Dependencies:
CC: