Autograd indices/values and sparse_coo ctor (pytorch#13001)
Summary:
Reopen of pytorch#11253 after fixing a bug in index_select
Pull Request resolved: pytorch#13001

Differential Revision: D10514987

Pulled By: SsnL

fbshipit-source-id: 399a83a1d3246877a3523baf99aaf1ce8066f33f
ssnl authored and facebook-github-bot committed Oct 24, 2018
1 parent e0f21a4 commit 46162cc
Showing 46 changed files with 1,462 additions and 995 deletions.
4 changes: 3 additions & 1 deletion aten/src/ATen/Declarations.cwrap
@@ -3266,7 +3266,9 @@
name: alias
return: THTensor*
cpu_half: True
variants: [function]
variants:
- method
- function
options:
- cname: newWithTensor
arguments:
27 changes: 16 additions & 11 deletions aten/src/ATen/SparseTensorImpl.cpp
@@ -22,18 +22,18 @@ namespace {
// a scalar and have one element)
//
// Thus, an empty sparse tensor should be a 1-dimensional tensor of size [0].
// Furthermore, we have dim == sparseDims + denseDims; since this is a sparse
// tensor, let us say that an empty sparse tensor has sparseDims == 1 and
// denseDims == 0. (There is a degree of freedom here, but given that this
// is a sparse dimension, it seems reasonable to demand that sparseDims > 0).
// Furthermore, we have dim == sparse_dim + dense_dim; since this is a sparse
// tensor, let us say that an empty sparse tensor has sparse_dim == 1 and
// dense_dim == 0. (There is a degree of freedom here, but given that this
// is a sparse dimension, it seems reasonable to demand that sparse_dim > 0).
//
// This means that we allocate a [1,0] size indices tensor and a [0] size
// values tensor for such an empty tensor.
SparseTensorImpl::SparseTensorImpl(at::TensorTypeId type_id, const caffe2::TypeMeta& data_type)
: TensorImpl(type_id, data_type, nullptr, false)
, size_{0}
, sparseDims_(1)
, denseDims_(0)
, sparse_dim_(1)
, dense_dim_(0)
, indices_(at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorIdToDeviceType(type_id)).dtype(ScalarType::Long)))
, values_(at::empty({0}, at::initialTensorOptions().device(sparseTensorIdToDeviceType(type_id)).dtype(dataTypeToScalarType(data_type.id())))) {}
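For illustration only (not part of this diff), a minimal Python sketch of the empty-tensor convention described above; the size-only constructor is the one this commit's own error messages recommend:

    import torch

    x = torch.sparse_coo_tensor((0,))        # empty sparse tensor of size [0]
    print(x._indices().shape)                # expected: torch.Size([1, 0])
    print(x._values().shape)                 # expected: torch.Size([0])
    print(x.sparse_dim(), x.dense_dim())     # expected: 1 0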

@@ -67,7 +67,7 @@ void SparseTensorImpl::set_storage_offset(int64_t storage_offset) {
}

int64_t SparseTensorImpl::dim() const {
return sparseDims_ + denseDims_;
return sparse_dim_ + dense_dim_;
}
TensorImpl* SparseTensorImpl::maybe_zero_dim(bool condition_when_zero_dim) {
AT_CHECK(condition_when_zero_dim == (dim() == 0),
@@ -83,17 +83,22 @@ int64_t SparseTensorImpl::storage_offset() const {
AT_ERROR("sparse tensors do not have storage");
}
void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, const Tensor& values) {
AT_ASSERT(!indices.is_variable() && !values.is_variable()); // They should be plain tensors!

AT_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout());
AT_CHECK(!values.is_sparse(), "expected values to be a dense tensor, but got values of layout ", values.layout());

AT_CHECK(values.type().toSparse() == type(), "values type must match sparse tensor type");
AT_CHECK(indices.type().scalarType() == kLong, "indices must be an int64 tensor");
AT_CHECK(indices.type().backend() == values.type().backend(), "backend of indices (", indices.type().backend(), ") must match backend of values (", values.type().backend(), ")");
AT_CHECK(!indices.is_cuda() || indices.get_device() == values.get_device(), "device of indices (", indices.get_device(), ") must match device of values (", values.get_device(), ")");

AT_CHECK(indices.dim() == 2, "indices must be nDim x nnz, but got: ", indices.sizes());
AT_CHECK(indices.dim() == 2, "indices must be sparse_dim x nnz, but got: ", indices.sizes());
AT_CHECK(indices.size(1) == values.size(0), "indices and values must have same nnz, but got nnz from indices: ", indices.size(1), ", nnz from values: ", values.size(0));
AT_CHECK(indices.size(0) == sparseDims_, "indices has incorrect first dimension, expected ", sparseDims_, ", got ", indices.size(0));
AT_CHECK(values.dim() == denseDims_ + 1, "values has incorrect number of dimensions, expected ", denseDims_ + 1, ", got ", values.dim());
AT_CHECK(indices.size(0) == sparse_dim_, "indices has incorrect first dimension, expected ", sparse_dim_, ", got ", indices.size(0));
AT_CHECK(values.dim() == dense_dim_ + 1, "values has incorrect number of dimensions, expected ", dense_dim_ + 1, ", got ", values.dim());

auto dense_size_original = sizes().slice(sparseDims_);
auto dense_size_original = sizes().slice(sparse_dim_);
std::vector<int64_t> expected_values_size_vec = {values.size(0)};
expected_values_size_vec.insert(expected_values_size_vec.end(), dense_size_original.begin(), dense_size_original.end());
IntList expected_values_size(expected_values_size_vec);
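The shape contract enforced above, as a hedged Python sketch with invented sizes: for a sparse tensor of size (5, 5, 3) with sparse_dim = 2 and dense_dim = 1, indices must be (sparse_dim, nnz) and int64, and values must be (nnz, dense sizes...):

    import torch

    nnz = 4
    i = torch.randint(0, 5, (2, nnz))            # (sparse_dim, nnz), int64
    v = torch.randn(nnz, 3)                      # (nnz, dense sizes...)
    x = torch.sparse_coo_tensor(i, v, (5, 5, 3))
    print(x.sparse_dim(), x.dense_dim(), x._nnz())   # 2 1 4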
93 changes: 46 additions & 47 deletions aten/src/ATen/SparseTensorImpl.h
@@ -9,18 +9,18 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
// Stored in COO format, indices + values.

// INVARIANTS:
// _sparseDims: range [0, len(shape)]; _sparseDims + _denseDims = len(shape)
// _denseDims : range [0, len(shape)]; _sparseDims + _denseDims = len(shape)
// _indices.shape: dimensionality: 2, shape: (_sparseDims, nnz)
// _values.shape: dimensionality: 1 + _denseDims. shape: (nnz, shape[_sparseDims:])
// sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
// dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
// _indices.shape: dimensionality: 2, shape: (sparse_dim, nnz)
// _values.shape: dimensionality: 1 + dense_dim. shape: (nnz, shape[sparse_dim:])

// The true size of the sparse tensor (e.g., if you called to_dense()
// on it). When THTensor merges into TensorImpl, this field
// should move to the parent class.
std::vector<int64_t> size_;

int64_t sparseDims_ = 0; // number of sparse dimensions
int64_t denseDims_ = 0; // number of dense dimensions
int64_t sparse_dim_ = 0; // number of sparse dimensions
int64_t dense_dim_ = 0; // number of dense dimensions

Tensor indices_; // always a LongTensor
Tensor values_;
@@ -39,8 +39,8 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
explicit SparseTensorImpl(at::TensorTypeId, const caffe2::TypeMeta&);

int64_t nnz() const { return values_.size(0); }
int64_t sparseDims() const { return sparseDims_; }
int64_t denseDims() const { return denseDims_; }
int64_t sparse_dim() const { return sparse_dim_; }
int64_t dense_dim() const { return dense_dim_; }
bool coalesced() const { return coalesced_; }
Tensor indices() const { return indices_; }
Tensor values() const { return values_; }
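In Python, the renamed accessors surface as sparse_dim()/dense_dim(); the new public indices()/values() accessors, unlike the underscore variants, require a coalesced tensor. A sketch, continuing the example above:

    y = x.coalesce()
    print(y.indices().shape)   # (sparse_dim, nnz); raises if not coalesced
    print(y.values().shape)    # (nnz, dense sizes...)
    print(x._indices().shape)  # underscore variant works on uncoalesced tensors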
@@ -60,16 +60,16 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
const Storage& storage() const override;
int64_t storage_offset() const override;

// WARNING: This function does NOT preserve invariants of sparseDims/denseDims with
// WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim with
// respect to indices and values
void raw_resize_(int64_t sparseDims, int64_t denseDims, IntList size) {
void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
size_ = size.vec();
sparseDims_ = sparseDims;
denseDims_ = denseDims;
sparse_dim_ = sparse_dim;
dense_dim_ = dense_dim;
refresh_numel();
}

// NOTE: This function preserves invariants of sparseDims/denseDims with respect to
// NOTE: This function preserves invariants of sparse_dim/dense_dim with respect to
// indices and values.
//
// NOTE: This function supports the following cases:
@@ -91,75 +91,73 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
// and for API consistency we don't support it).
// 4. When we attempt to shrink the size of any of the sparse dimensions on a non-empty sparse tensor
// (this could make some of the stored indices out-of-bound and thus unsafe).
void resize_(int64_t sparseDims, int64_t denseDims, IntList size) {
AT_CHECK(sparseDims + denseDims == size.size(), "number of dimensions must be sparseDims (", sparseDims, ") + denseDims (", denseDims, "), but got ", size.size());
void resize_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
AT_CHECK(sparse_dim + dense_dim == size.size(), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());
if (nnz() > 0) {
auto alt_options_msg = "You could try the following options:\n\
1. If you need an empty sparse tensor of this size, call `x=torch.sparse_coo_tensor(size)`.\n\
1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\
2. If you need to resize this tensor, you have the following options:\n\
1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\
2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor.";

AT_CHECK(sparseDims == sparseDims_,
"changing the number of sparse dimensions (from ", sparseDims_, " to ", sparseDims, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);
AT_CHECK(sparse_dim == sparse_dim_,
"changing the number of sparse dimensions (from ", sparse_dim_, " to ", sparse_dim, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);

AT_CHECK(denseDims == denseDims_,
"changing the number of dense dimensions (from ", denseDims_, " to ", denseDims, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);
AT_CHECK(dense_dim == dense_dim_,
"changing the number of dense dimensions (from ", dense_dim_, " to ", dense_dim, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);

bool shrinking_sparse_dims = false;
bool shrinking_dense_dims = false;
auto sparse_size_original = sizes().slice(0, sparseDims);
auto sparse_size_new = size.slice(0, sparseDims);
for (int i = 0; i < sparseDims; i++) {
bool shrinking_dense_dim = false;
auto sparse_size_original = sizes().slice(0, sparse_dim);
auto sparse_size_new = size.slice(0, sparse_dim);
for (int i = 0; i < sparse_dim; i++) {
if (sparse_size_new[i] < sparse_size_original[i]) {
shrinking_sparse_dims = true;
break;
}
}
auto dense_size_original = sizes().slice(sparseDims);
auto dense_size_new = size.slice(sparseDims);
for (int i = 0; i < denseDims; i++) {
auto dense_size_original = sizes().slice(sparse_dim);
auto dense_size_new = size.slice(sparse_dim);
for (int i = 0; i < dense_dim; i++) {
if (dense_size_new[i] < dense_size_original[i]) {
shrinking_dense_dims = true;
shrinking_dense_dim = true;
break;
}
}

AT_CHECK(!shrinking_sparse_dims,
"shrinking the size of sparse dimensions (from ", sparse_size_original, " to ", sparse_size_new, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);

AT_CHECK(!shrinking_dense_dims,
AT_CHECK(!shrinking_dense_dim,
"shrinking the size of dense dimensions (from ", dense_size_original, " to ", dense_size_new, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);
}

if ((!size.equals(size_)) || (sparseDims != sparseDims_) || (denseDims != denseDims_)) {
std::vector<int64_t> values_size = {values().size(0)};
auto dense_size = size.slice(sparseDims);
if ((!size.equals(size_)) || (sparse_dim != sparse_dim_) || (dense_dim != dense_dim_)) {
auto nnz = values().size(0);
std::vector<int64_t> values_size = {nnz};
auto dense_size = size.slice(sparse_dim);
values_size.insert(values_size.end(), dense_size.begin(), dense_size.end());
values_.resize_(values_size);

std::vector<int64_t> indices_size = indices().sizes().vec();
indices_size[0] = sparseDims;
indices_.resize_(indices_size);
indices_.resize_({sparse_dim, nnz});
}

size_ = size.vec();
sparseDims_ = sparseDims;
denseDims_ = denseDims;
sparse_dim_ = sparse_dim;
dense_dim_ = dense_dim;
refresh_numel();
}
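A hedged Python sketch of these cases through the public sparse_resize_ binding; behavior inferred from the checks above:

    import torch

    x = torch.sparse_coo_tensor(torch.tensor([[0], [1]]),
                                torch.tensor([1.0]), (2, 2))   # nnz == 1
    x.sparse_resize_((4, 4), 2, 0)       # OK: grows sparse dims, keeps counts
    # x.sparse_resize_((1, 4), 2, 0)     # error: would shrink a sparse dim
    # x.sparse_resize_((4, 4, 2), 2, 1)  # error: would change dense_dim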

// NOTE: this function will resize the sparse tensor and also set `indices` and `values` to empty.
void resize_and_clear_(int64_t sparseDims, int64_t denseDims, IntList size) {
AT_CHECK(sparseDims + denseDims == size.size(), "number of dimensions must be sparseDims (", sparseDims, ") + denseDims (", denseDims, "), but got ", size.size());
void resize_and_clear_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
AT_CHECK(sparse_dim + dense_dim == size.size(), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());

size_ = size.vec();
sparseDims_ = sparseDims;
denseDims_ = denseDims;
sparse_dim_ = sparse_dim;
dense_dim_ = dense_dim;

auto empty_indices = at::empty({sparseDims, 0}, indices().options());
auto empty_indices = at::empty({sparse_dim, 0}, indices().options());
std::vector<int64_t> values_size = {0};
auto dense_size = sizes().slice(sparseDims);
auto dense_size = sizes().slice(sparse_dim);
values_size.insert(values_size.end(), dense_size.begin(), dense_size.end());
auto empty_values = at::empty(values_size, values().options());
set_indices_and_values_unsafe(empty_indices, empty_values);
@@ -169,9 +167,10 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
void set_coalesced(bool coalesced) { coalesced_ = coalesced; }

// NOTE: this function is only used internally and not exposed to Python frontend
void set_nnz_and_narrow(int64_t nnz) {
indices_ = indices_.narrow(1, 0, nnz);
values_ = values_.narrow(0, 0, nnz);
void set_nnz_and_narrow(int64_t new_nnz) {
AT_ASSERT(new_nnz <= nnz());
indices_ = indices_.narrow(1, 0, new_nnz);
values_ = values_.narrow(0, 0, new_nnz);
}
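set_nnz_and_narrow is internal, but its effect is observable through coalesce(), which deduplicates indices and narrows storage to the new nnz. A Python sketch:

    import torch

    i = torch.tensor([[0, 0, 1],
                      [0, 0, 1]])        # the index (0, 0) appears twice
    v = torch.tensor([1.0, 2.0, 3.0])
    x = torch.sparse_coo_tensor(i, v, (2, 2))
    print(x._nnz())              # 3: duplicates are stored as-is
    print(x.coalesce()._nnz())   # 2: duplicate values summed, storage narrowed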

// Takes indices and values and directly puts them into the sparse tensor, no copy.
111 changes: 111 additions & 0 deletions aten/src/ATen/SparseTensorUtils.h
@@ -0,0 +1,111 @@
#include <ATen/ATen.h>
#include <ATen/SparseTensorImpl.h>

namespace at { namespace sparse {

// Just for documentation purposes
using SparseTensor = Tensor;
using LongTensor = Tensor;
using IntTensor = Tensor;
using SparseType = Type;

// This is an internal utility function for getting at the SparseTensorImpl,
// so that we can write sparse tensor specific accessors for special fields
// in SparseTensor. You should only use this for writing low level
// setters/getters for SparseTensorImpl fields; otherwise, you should use
// the low level setters/getters that were implemented using this.
//
// This may be called repeatedly, so make sure it's pretty cheap.
inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
AT_ASSERTM(!self.is_variable(), "_internal_get_SparseTensorImpl: should not be a variable");
AT_ASSERTM(self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl());
}

// Port of the old THCSTensor_(checkGPU), but it doesn't really belong here
// because it is more general
// NB: I dropped kernelP2PEnabled support
// NB: This only works if the tensors are KNOWN to be CUDA.
// TODO: Generalize it so it works on CPU as well
inline bool check_device(ArrayRef<Tensor> ts) {
if (ts.empty()) {
return true;
}
int64_t curDevice = current_device();
for (const Tensor& t : ts) {
if (t.get_device() != curDevice) return false;
}
return true;
}

// Takes indices and values and directly puts them into the sparse tensor, no
// copy. This used to be called THSTensor_(_move)
inline void alias_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values) {
get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
}

// Take indices and values and makes a (data) copy of them to put into the sparse
// indices/values. This used to be called THSTensor_(_set)
inline void copy_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values, bool non_blocking) {
alias_into_sparse(self, self._indices().type().copy(indices, non_blocking), self._values().type().copy(values, non_blocking));
}

// TODO: put this into the public API
inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) {
return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl();
}

inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) {
return self.sparse_dim() == src.sparse_dim() && self.dense_dim() == src.dense_dim();
}

// Give us a new values tensor, with the same dimensionality
// as 'values' but with a new number of non-zero elements.
// TODO: Expose this for real in ATen, some day?
// NB: Doesn't preserve data.
inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
std::vector<int64_t> size = values.sizes().vec();
size[0] = nnz;
return at::empty(size, values.options());
}

// This helper function flattens a sparse indices tensor (a LongTensor) into a 1D
// indices tensor. E.g.,
// input = [[2, 4, 0],
// [3, 1, 10]]
// full_size = [5, 12]
// output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
//
// In other words, assuming each column `indices[:, i]` is a valid index into a
// tensor `t` of shape `full_size`, this returns the corresponding indices into
// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
// If force_clone is true, the result will be forced to be a clone of the input.
inline LongTensor flatten_indices(const Tensor& indices, IntList full_size, bool force_clone = false) {
int64_t sparse_dim = indices.size(0);
if (sparse_dim == 1) {
if (force_clone) {
return indices.squeeze(0).clone();
} else {
return indices.squeeze(0);
}
} else {
std::vector<int64_t> indices_mult_cpu_vec;
indices_mult_cpu_vec.resize(sparse_dim); // resize, not reserve: elements are assigned via operator[] below
int64_t mult = 1;
for (int64_t i = sparse_dim - 1; i >= 0; i--) {
indices_mult_cpu_vec[i] = mult;
mult *= full_size[i];
}
auto indices_mult_cpu = indices.type().cpu()
.tensorFromBlob(indices_mult_cpu_vec.data(), /*size=*/{sparse_dim, 1});
// NB: must be blocking because this blob may be freed after this closure,
// and non_blocking copy will see garbage.
auto indices_mult = indices_mult_cpu.to(indices.device(), /*non_blocking=*/false);
// Ideally we want matmul but matmul is slow on CPU Long and not implemented
// on CUDA Long. So mul is faster.
return indices.mul(indices_mult).sum(0);
}
}

}} // namespace at::sparse
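A Python re-derivation of flatten_indices for the worked example in its comment (a sketch of the arithmetic, not the ATen implementation):

    import torch

    indices = torch.tensor([[2, 4, 0],
                            [3, 1, 10]])
    # row-major multipliers for full_size = [5, 12]: [12, 1]
    mult = torch.tensor([[12], [1]])
    print(indices.mul(mult).sum(0))   # tensor([27, 49, 10])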
14 changes: 10 additions & 4 deletions aten/src/ATen/core/Tensor.h
@@ -404,6 +404,7 @@ class CAFFE2_API Tensor {
Tensor & log_normal_(double mean=1, double std=2, Generator * generator=nullptr);
Tensor & exponential_(double lambd=1, Generator * generator=nullptr);
Tensor & geometric_(double p, Generator * generator=nullptr);
Tensor alias() const;
Tensor abs() const;
Tensor & abs_();
Tensor acos() const;
@@ -621,17 +622,22 @@ class CAFFE2_API Tensor {
Tensor & sub_(Scalar other, Scalar alpha=1);
Tensor addmm(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1) const;
Tensor & addmm_(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1);
Tensor & sparse_resize_(IntList size, int64_t sparseDims, int64_t denseDims);
Tensor & sparse_resize_and_clear_(IntList size, int64_t sparseDims, int64_t denseDims);
Tensor & sparse_resize_(IntList size, int64_t sparse_dim, int64_t dense_dim);
Tensor & sparse_resize_and_clear_(IntList size, int64_t sparse_dim, int64_t dense_dim);
Tensor sparse_mask(SparseTensorRef mask) const;
Tensor to_dense() const;
int64_t _sparseDims() const;
int64_t _denseDims() const;
int64_t sparse_dim() const;
int64_t _dimI() const;
int64_t dense_dim() const;
int64_t _dimV() const;
int64_t _nnz() const;
Tensor coalesce() const;
bool is_coalesced() const;
Tensor _indices() const;
Tensor _values() const;
Tensor & _coalesced_(bool coalesced);
Tensor indices() const;
Tensor values() const;
int64_t numel() const;
std::vector<Tensor> unbind(int64_t dim=0) const;
int64_t get_device() const;
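The headline change of this PR, namely that the sparse_coo_tensor constructor and the indices()/values() getters participate in autograd, as a hedged Python sketch (shapes invented for illustration):

    import torch

    i = torch.tensor([[0, 1, 2, 3],
                      [0, 1, 2, 3]])
    v = torch.randn(4, 3, requires_grad=True)
    x = torch.sparse_coo_tensor(i, v, (5, 5, 3))   # differentiable ctor
    x.coalesce().values().sum().backward()         # differentiable getter
    print(v.grad.shape)   # expected: torch.Size([4, 3])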