Skip to content

Commit

Permalink
[TE] Support negative indices (apache#9023)
Browse files Browse the repository at this point in the history
* initial change

* more explicit api

* switch to select

* add support for negative indices

* reduce things further

* lint

* to CamelCase

* unit test

Co-authored-by: Andrew Zhao Luo <[email protected]>
  • Loading branch information
AndrewZhaoLuo and Andrew Zhao Luo authored Jan 12, 2022
1 parent 586944e commit 80f9b9f
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 6 deletions.
32 changes: 32 additions & 0 deletions include/tvm/te/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ class TensorNode : public DataProducerNode {
* or intermediate computation result.
*/
class Tensor : public DataProducer {
private:
/*!
* \brief Helper for indexing operations into tensors
* \param indices The indices
* \param support_negative_indices Whether to normalize indices in the case of negative indices.
* \return the result expression representing tensor read.
*/
inline PrimExpr IndexTensor(Array<PrimExpr> indices, bool support_negative_indices) const;

public:
TVM_DLL Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index);
/*!
Expand Down Expand Up @@ -138,6 +147,29 @@ class Tensor : public DataProducer {
* \return the result expression representing tensor read.
*/
TVM_DLL PrimExpr operator()(Array<Var> indices) const;
/*!
* \brief Take elements from the tensor with support for negative indices.
* \param args The indices
* \return the result expression representing tensor read.
*/
template <typename... Args>
TVM_DLL PrimExpr IndexWithNegativeIndices(Args&&... args) const {
Array<PrimExpr> indices{std::forward<Args>(args)...};
return IndexWithNegativeIndices(indices);
}
/*!
* \brief Take elements from the tensor with support for negative indices.
* \param indices the indices.
* \return the result expression representing tensor read.
*/
TVM_DLL PrimExpr IndexWithNegativeIndices(Array<PrimExpr> indices) const;
/*!
* \brief Take elements from the tensor with support for negative indices.
* \param indices the indices.
* \return the result expression representing tensor read.
*/
TVM_DLL PrimExpr IndexWithNegativeIndices(Array<Var> indices) const;

/*!
* \brief data structure to represent a slice that fixes first k coordinates.
* This is used to enable syntax sugar of Tensor[x][y][z] to get the element.
Expand Down
33 changes: 27 additions & 6 deletions src/te/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,39 @@ IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name)
Var var(std::string name_hint, DataType t) { return Var(name_hint, t); }

// Tensor
inline PrimExpr Tensor::IndexTensor(Array<PrimExpr> indices, bool support_negative_indices) const {
Array<PrimExpr> shape = (*this)->shape;

if (shape.size() != 0) {
ICHECK_EQ(shape.size(), indices.size())
<< "Tensor dimension mismatch in read "
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
}

if (support_negative_indices) {
for (size_t i = 0; i < shape.size(); i++) {
PrimExpr new_index =
Select(indices[i] < make_const(indices[i]->dtype, 0), indices[i] + shape[i], indices[i]);
indices.Set(i, new_index);
}
}
return ProducerLoad((*this), indices);
}

PrimExpr Tensor::operator()(Array<Var> indices) const {
Array<PrimExpr> arr(indices.begin(), indices.end());
return operator()(arr);
}

PrimExpr Tensor::operator()(Array<PrimExpr> indices) const {
if (ndim() != 0) {
ICHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read "
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
}
PrimExpr Tensor::operator()(Array<PrimExpr> indices) const { return IndexTensor(indices, false); }

return ProducerLoad((*this), indices);
PrimExpr Tensor::IndexWithNegativeIndices(Array<Var> indices) const {
Array<PrimExpr> arr(indices.begin(), indices.end());
return IndexWithNegativeIndices(arr);
}

PrimExpr Tensor::IndexWithNegativeIndices(Array<PrimExpr> indices) const {
return IndexTensor(indices, true);
}

String TensorNode::GetNameHint() const {
Expand Down
11 changes: 11 additions & 0 deletions tests/cpp/tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,14 @@ TEST(Tensor, Reduce) {
{m, n}, [&](Var i, Var j) { return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); }, "C");
LOG(INFO) << C->op.as<te::ComputeOpNode>()->body;
}

TEST(Tensor, Indexing) {
using namespace tvm;
using namespace tvm::te;

Var x("x"), y("y");
te::Tensor A = te::placeholder({x, y}, DataType::Float(32), "A");
LOG(INFO) << A(0, 0);
LOG(INFO) << A.IndexWithNegativeIndices(-1, -1);
LOG(INFO) << A.IndexWithNegativeIndices(0, -1);
}

0 comments on commit 80f9b9f

Please sign in to comment.