feat(ops): support np.argpartition #19588
Conversation
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files:

@@ Coverage Diff @@
## master #19588 +/- ##
=======================================
Coverage 78.28% 78.29%
=======================================
Files 498 498
Lines 45382 45454 +72
Branches 8362 8373 +11
=======================================
+ Hits 35528 35588 +60
- Misses 8101 8107 +6
- Partials 1753 1759 +6
Thanks for the PR!
keras/src/backend/torch/numpy.py (outdated)

```python
for _ in range(x.dim() - 1):
    set_to_zero = torch.vmap(set_to_zero)
proxy = set_to_zero(torch.ones(x.shape), bottom_ind)
```
Should ones be int32?
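If it helps, the dtype difference in question (illustrative snippet, not the PR's code):

```python
import torch

x = torch.tensor([[3, 1, 2], [6, 5, 4]])
ones_default = torch.ones(x.shape)                 # float32 by default
ones_int = torch.ones(x.shape, dtype=torch.int32)  # what the question proposes
assert ones_default.dtype == torch.float32
assert ones_int.dtype == torch.int32
```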
keras/src/ops/numpy.py (outdated)

```python
@keras_export(["keras.ops.argpartition", "keras.ops.numpy.argpartition"])
def argpartition(x, kth, axis=-1):
    """Performs an indirect partition along the given axis. It returns an array
```
Add 2 line breaks after "the given axis."
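For context, the semantics being documented mirror `numpy.argpartition`; a quick illustration (not code from the PR):

```python
import numpy as np

x = np.array([7, 2, 9, 1, 5])
ind = np.argpartition(x, kth=2)
part = x[ind]
# position 2 holds the value that would be there after a full sort;
# everything left of it is <=, everything right of it is >=
assert part[2] == np.sort(x)[2]
assert (part[:2] <= part[2]).all() and (part[3:] >= part[2]).all()

# axis=None operates on the flattened array
ind_flat = np.argpartition(x, kth=2, axis=None)
```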
keras/src/ops/numpy.py (outdated)

```python
        If provided with a sequence of k-th it will partition all of them
        into their sorted position at once.
    axis: Axis along which to sort. The default is -1 (the last axis).
        If None, the flattened array is used.
```
Backticks around `None`.
```python
def argpartition(x, kth, axis=-1):
    x = convert_to_tensor(x)

    if standardize_dtype(x.dtype) not in ["int32", "int64"]:
```
Use `x.dtype.is_integer`.
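i.e. something like this (sketch; `is_integer` is a property on `tf.DType`):

```python
import tensorflow as tf

x = tf.constant([3.0, 1.0, 2.0])
# rely on the dtype property instead of comparing dtype names against a list
if not x.dtype.is_integer:
    x = tf.cast(x, tf.int32)
```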
```python
if standardize_dtype(x.dtype) not in ["int32", "int64"]:
    x = tf.cast(x, tf.int32)

x = tf.experimental.numpy.swapaxes(x, axis, -1)
```
Use `swapaxes` (defined in this very file).
```python
def set_to_zero(args):
    a, i = args
    updates = tf.reshape(tf.zeros_like(i, dtype=tf.float32), [-1])
```
Why float32 here?
```python
top_ind = tf.math.top_k(proxy, tf.shape(x)[-1] - kth - 1)[1]

out = tf.concat([bottom_ind, top_ind], axis=x.ndim - 1)
return tf.experimental.numpy.swapaxes(out, -1, axis)
```
Likewise, use `swapaxes`.
```python
set_to_zero = tf.vectorized_map(
    set_to_zero, (tf.ones(tf.shape(x)), bottom_ind)
)
proxy = set_to_zero((tf.ones(tf.shape(x)), bottom_ind))
```
This appears to be reimplementing the corresponding part of the JAX code that uses vmap. But how did you come up with this exactly? I can't quite follow the logic.
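For readers following the thread, the proxy trick can be sketched in 1-D (my reconstruction, not the PR's exact code): take the `kth + 1` smallest indices via `top_k` on the negated input, zero those positions out in a proxy of ones, then a second `top_k` on the proxy returns the remaining positions.

```python
import tensorflow as tf

def argpartition_1d_sketch(x, kth):
    # 1-D illustration of the proxy/top_k trick; not the PR's actual code
    n = tf.shape(x)[-1]
    # indices of the kth+1 smallest values (they land left of the partition point)
    bottom_ind = tf.math.top_k(-x, kth + 1).indices
    # proxy of ones with zeros at the positions already taken
    proxy = tf.tensor_scatter_nd_update(
        tf.ones_like(x, dtype=tf.float32),
        tf.expand_dims(bottom_ind, -1),
        tf.zeros([kth + 1], dtype=tf.float32),
    )
    # the remaining (still-one) positions, in tie-broken index order
    top_ind = tf.math.top_k(proxy, n - kth - 1).indices
    return tf.concat([bottom_ind, top_ind], axis=-1)

print(argpartition_1d_sketch(tf.constant([7.0, 2.0, 9.0, 1.0, 5.0]), 2))
```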
@fchollet
Thanks for the updates! Please also format the code by running `sh shell/format.sh`, and edit API files by running `sh shell/api_gen.sh`.
keras/src/backend/torch/numpy.py (outdated)

```python
def argpartition(x, kth, axis=-1):
    x = convert_to_tensor(x)
    x = cast(x, "int32")
```
You can just do `x = convert_to_tensor(x, dtype="int32")` instead of a separate cast call.
keras/src/ops/numpy.py (outdated)

```python
def compute_output_spec(self, x):
    if not isinstance(self.kth, int):
        raise ValueError(
```
Move this to the constructor.
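i.e. roughly this shape (a sketch mirroring the pattern of the other `Operation` subclasses in this file; import path assumed):

```python
from keras.src.ops.operation import Operation


class Argpartition(Operation):
    def __init__(self, kth, axis=-1):
        super().__init__()
        # validate eagerly at construction time rather than in compute_output_spec
        if not isinstance(kth, int):
            raise ValueError(f"`kth` must be an integer. Received: kth={kth}")
        self.kth = kth
        self.axis = axis
```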
keras/src/ops/numpy.py (outdated)

```python
)

dtype = "int32"
if backend.backend() in ("torch"):
```
This check is broken: `("torch")` is just a parenthesized string, not a tuple, since the trailing comma is missing. `in` then falls back to substring matching instead of tuple membership.
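Spelled out (runnable illustration):

```python
# ("torch") is just a parenthesized string, so `in` performs substring matching
assert ("torch") == "torch"
assert "torch" in ("torch")        # True only by accident of substring containment
assert "tensorflow" not in ("torch")
# the intended one-element tuple needs a trailing comma
assert "torch" in ("torch",)       # genuine membership test
```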
keras/src/ops/numpy.py (outdated)

```python
dtype = "int32"
if backend.backend() in ("torch"):
    dtype = "int64"
```
We should not have backend-dependent behaviors, so all backends should return the same type. We need to cast in backend ops if that's not the case.
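For instance, in the torch backend the indices could be cast before returning (illustrative; torch index ops such as `argsort`/`topk` produce int64):

```python
import torch

ind = torch.argsort(torch.tensor([3.0, 1.0, 2.0]))
assert ind.dtype == torch.int64   # torch's default index dtype
ind = ind.to(torch.int32)         # cast so every backend returns int32
```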
The torch unit tests fail -- it looks like we need to cast outputs to int32 for torch. https://github.com/keras-team/keras/actions/runs/8853599761/job/24314819196?pr=19588
Thank you for the contribution! LGTM
Torch tests seem to fail on GPU with device placement issues: https://btx.cloud.google.com/invocations/43f688ba-e1b0-4100-8725-1b7756581e52/targets/keras%2Fgithub%2Fubuntu%2Fgpu%2Ftorch%2Fcontinuous/log
I think in the line …
I fixed it. The issue was the use of …
@fchollet
(Squashed merge message from syncing the branch with master; it enumerates the unrelated upstream commits #19488 through #19655 and is trimmed here to this PR's own commits:)

* feat(ops): support np.argpartition
* updated documentation, type-casting, and tf implementation
* fixed tf implementation
* added torch cast to int32
* updated torch type and API generated files
* added torch output type cast
* Fix `argpartition` cuda bug in torch (#19634)
Adds support for the `numpy.argpartition` operation.
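Usage, once this lands, should mirror NumPy (a hedged sketch, not an example from the PR itself):

```python
import numpy as np
from keras import ops

x = np.array([7, 2, 9, 1, 5])
ind = ops.argpartition(x, kth=2)
part = ops.take(x, ind)
# the value at position 2 is in its final sorted place; smaller values sit
# before it and larger values after it, each side in unspecified order
print(part)
```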