Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix index.Tensor, index_put batching rules (#862)
Fixes #859 Start reading at `NOTE: [advanced indexing (index.Tensor) batch rule]` in the code for details. This PR rewrites the index.Tensor and index_put batching rules. The TL;DR is: - advanced indexing has different behavior depending on if the "advanced indices are adjacent": https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing - we have to take this into account in our batching rules, because index.Tensor and index_put handle these internally. Test Plan - I added new test cases for getitem and aten.ops.index_put via OpInfo testing. Future - primtorch should have a sane decomposition that we can use - We haven't fixed the index_put_ batching rule yet. TODO later... - Upstream our test cases (see next section) into pytorch/pytorch
- Loading branch information