

[JAX]: Support jax.lax.select_n operation for JAX #28025

Open: 11happy wants to merge 3 commits into base: master

Conversation

11happy (Contributor) commented Dec 11, 2024

Overview:
This pull request fixes #26570.

Testing:

  • Tested the updated code.
  • Verified that other functionalities remain unaffected.
    [Screenshot from 2024-12-12 00-09-17]

Dependencies:

  • No dependencies on other pull requests.

CC:

@11happy 11happy requested review from a team as code owners December 11, 2024 18:40
@github-actions github-actions bot added the category: TF FE (OpenVINO TensorFlow FrontEnd) and category: JAX FE (OpenVINO JAX FrontEnd) labels Dec 11, 2024
@sys-openvino-ci sys-openvino-ci added the ExternalPR (External contributor) label Dec 11, 2024
rkazants (Contributor) commented:

build_jenkins

@11happy 11happy requested a review from rkazants December 13, 2024 04:43
11happy (Contributor Author) commented Dec 14, 2024

Humble ping, @rkazants!

11happy (Contributor Author) commented Dec 16, 2024

Hello @rkazants, please let me know if there's anything needed from my end.

Comment on lines 23 to 36
OutputVector cases_vector(num_inputs - 1);
for(int ind = 1; ind < num_inputs; ++ind) {
cases_vector[ind - 1] = context.get_input(ind);
}

Output<Node> cases = std::make_shared<v0::Concat>(cases_vector, 0);
auto which_shape = which.get_shape();
std::vector<int64_t> cases_reshape_shape = {num_inputs-1,which_shape[0]};
std::vector<int64_t> which_reshape_shape = {1,which_shape[0]};

cases = std::make_shared<v1::Reshape>(cases, ov::op::v0::Constant::create(element::i64, Shape{2}, cases_reshape_shape), false);
which = std::make_shared<v1::Reshape>(which, ov::op::v0::Constant::create(element::i64, Shape{2}, which_reshape_shape), false);
Output<Node> result = std::make_shared<v6::GatherElements>(cases, which, 0);
return {result};

Contributor commented:

Unfortunately, the current solution has a limitation: it uses get_shape(), which only works for static shapes in the graph, so the decomposition is not reshapeable.
Let us implement it as follows:

  1. Unsqueeze all cases along axis=0 and concatenate them along axis=0.
  2. Compute the shape of one of the original cases and broadcast which to this shape.
  3. Apply GatherElements with axis=0 to the concatenated cases and the broadcast which.
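The three steps above can be sketched in NumPy to check the semantics before building the C++ graph. This is a hypothetical helper (select_n_decomposed is not part of the PR), and each NumPy call only stands in for the corresponding OpenVINO op: np.expand_dims for Unsqueeze, np.concatenate for Concat, np.broadcast_to for ShapeOf plus Broadcast, and np.take_along_axis for GatherElements. One detail the steps leave implicit: GatherElements expects indices of the same rank as the data, so the broadcast which also needs a leading axis of size 1, and a boolean which has to be converted to an integer index type first.

```python
import numpy as np


def select_n_decomposed(which, *cases):
    """Hypothetical NumPy sketch of the suggested decomposition.

    NumPy stand-ins for OpenVINO ops:
    expand_dims -> Unsqueeze, concatenate -> Concat,
    broadcast_to -> ShapeOf + Broadcast, take_along_axis -> GatherElements.
    """
    # Step 1: unsqueeze every case at axis 0 and concatenate along axis 0,
    # giving shape [num_cases, *case_shape].
    stacked = np.concatenate([np.expand_dims(c, 0) for c in cases], axis=0)
    # Step 2: broadcast `which` to the shape of one original case.
    # In the OpenVINO graph this shape would come from ShapeOf on a case
    # instead of get_shape(), which keeps the decomposition reshapeable.
    idx = np.broadcast_to(np.asarray(which), np.shape(cases[0]))
    # A boolean `which` picks case 0 for False and case 1 for True,
    # so it can be used as an integer index after conversion.
    idx = idx.astype(np.int64)
    # GatherElements needs indices with the same rank as the data,
    # so give `which` a leading axis of size 1.
    idx = np.expand_dims(idx, 0)
    # Step 3: gather along axis 0, then drop the leading axis again.
    return np.take_along_axis(stacked, idx, axis=0)[0]
```

For example, with three 1-D cases a, b, c and which = [0, 2, 1], the result takes element 0 from a, element 1 from c, and element 2 from b, which is the same as np.choose(which, (a, b, c)).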

Contributor Author replied:

done



@pytest.mark.parametrize("input_shape", [1,2,3,4,5,6,7,8,9,10])
@pytest.mark.parametrize("input_type", [np.int32, np.int64, np.bool, bool])
rkazants (Contributor) commented Dec 16, 2024:

NumPy deprecated and then removed the np.bool alias, so we should use only the builtin bool.

Suggested change:
- @pytest.mark.parametrize("input_type", [np.int32, np.int64, np.bool, bool])
+ @pytest.mark.parametrize("input_type", [np.int32, np.int64, bool])

Contributor Author replied:

done

return jax_select_n, None, 'select_n'


@pytest.mark.parametrize("input_shape", [1,2,3,4,5,6,7,8,9,10])
rkazants (Contributor) commented Dec 16, 2024:

Let us use separate shapes for which and cases to check broadcasting.

Contributor Author replied:

While trying to do this, I got the following error:
TypeError: select which must be scalar or have the same shape as cases

Contributor Author added:

jax.lax.select_n does not allow which and cases to have different shapes: which must either be a scalar or have the same shape as cases.
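The error message pins down the only broadcasting select_n accepts: a scalar which against cases of any shape. A NumPy reference that enforces the same rule could back a test that parametrizes which separately as scalar or full-shape. The helper reference_select_n and the WHICH_AND_CASE_SHAPES list are hypothetical, not taken from the PR or from the JAX test suite, and np.choose is used only as a convenient reference for elementwise selection among cases.

```python
import numpy as np

# Hypothetical (which_shape, case_shape) pairs a test could parametrize over:
# a scalar `which` exercises broadcasting, a full-shape one does not.
WHICH_AND_CASE_SHAPES = [
    ((), (3,)),
    ((3,), (3,)),
    ((), (2, 3)),
    ((2, 3), (2, 3)),
]


def reference_select_n(which, *cases):
    """Elementwise selection among `cases`, enforcing the shape rule from
    the error message: `which` is a scalar or matches the case shape."""
    which = np.asarray(which)
    case_shape = np.shape(cases[0])
    if which.ndim != 0 and which.shape != case_shape:
        raise TypeError(
            "select which must be scalar or have the same shape as cases")
    # np.choose broadcasts a scalar index against every case.
    return np.choose(which.astype(np.int64), cases)
```

With which_shape == () every element comes from the same case, so iterating over these pairs would cover the scalar-broadcast path as well as the same-shape path that the current tests use.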

@11happy 11happy requested a review from rkazants December 19, 2024 17:28
Labels: category: JAX FE (OpenVINO JAX FrontEnd), category: TF FE (OpenVINO TensorFlow FrontEnd), ExternalPR (External contributor)
Development: successfully merging this pull request may close these issues:

[Good First Issue][JAX FE]: Support jax.lax.select_n operation for JAX
3 participants