Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Aug 3, 2017
1 parent 9ce878f commit c490123
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions src/operator/tensor/sparse_retain-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,9 @@ inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
if ((*in_attrs)[sr::kArr] == kRowSparseStorage) {
STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kIdx, kDefaultStorage);
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kOut, kRowSparseStorage);
} else { // fallback
type_assign(&(in_attrs->at(sr::kArr)), kDefaultStorage);
type_assign(&(in_attrs->at(sr::kIdx)), kDefaultStorage);
type_assign(&(out_attrs->at(sr::kOut)), kDefaultStorage);
}
type_assign(&(in_attrs->at(sr::kArr)), kRowSparseStorage);
type_assign(&(in_attrs->at(sr::kIdx)), kDefaultStorage);
type_assign(&(out_attrs->at(sr::kOut)), kRowSparseStorage);
return true;
}

Expand All @@ -74,16 +69,11 @@ inline bool SparseRetainBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 2U);
if (out_attrs->at(sr::kArr) == kRowSparseStorage) {
STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kOut, kDefaultStorage);
STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kIdx, kDefaultStorage);
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kIdx, kDefaultStorage);
} else {
type_assign(&(in_attrs->at(sr::kOut)), kDefaultStorage);
type_assign(&(in_attrs->at(sr::kIdx)), kDefaultStorage);
type_assign(&(out_attrs->at(sr::kArr)), kDefaultStorage);
type_assign(&(out_attrs->at(sr::kIdx)), kDefaultStorage);
}

type_assign(&(in_attrs->at(sr::kOut)), kDefaultStorage);
type_assign(&(in_attrs->at(sr::kIdx)), kDefaultStorage);
type_assign(&(out_attrs->at(sr::kArr)), kRowSparseStorage);
type_assign(&(out_attrs->at(sr::kIdx)), kDefaultStorage);
return true;
}

Expand Down

0 comments on commit c490123

Please sign in to comment.