Skip to content

Commit

Permalink
Fix index.Tensor, index_put batching rules (#862)
Browse files Browse the repository at this point in the history
Fixes #859

Start reading at `NOTE: [advanced indexing (index.Tensor) batch rule]`
in the code for details. This PR rewrites the index.Tensor and index_put
batching rules.

The TL;DR is:
- advanced indexing has different behavior depending on if the "advanced
indices are adjacent":
https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
- we have to take this into account in our batching rules, because
index.Tensor and index_put handle these internally.

Test Plan
- I added new test cases for getitem and aten.ops.index_put via OpInfo
testing.

Future
- primtorch should have a sane decomposition that we can use
- We haven't fixed the index_put_ batching rule yet. TODO later...
- Upstream our test cases (see next section) into pytorch/pytorch
  • Loading branch information
zou3519 committed Jun 13, 2022
1 parent 1a8b86c commit fac1d44
Show file tree
Hide file tree
Showing 3 changed files with 383 additions and 62 deletions.
Loading

0 comments on commit fac1d44

Please sign in to comment.