Modified squeeze Python APIs to fully support Dynamic Shape Definitions (#1325)

Old syntax:
```python
ops.squeeze(arg: Tensor, original_shape: List[int], dims: List[int]) -> Tensor
```
New syntax:
```python
ops.squeeze(arg: Tensor, dims: List[int]) -> Tensor
```

The new syntax no longer bakes the original tensor shape into the squeeze record as constants, so the input shape can remain dynamic.
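
For context, here is a minimal sketch of the new API through the nvFuser Python frontend. The shapes, dtype, and `define_tensor` keywords are illustrative and may differ slightly between versions:

```python
import torch
from nvfuser import FusionDefinition, DataType

with FusionDefinition() as fd:
    # Symbolic input: -1 marks dynamic extents, 1 marks a broadcast dim.
    t0 = fd.define_tensor(
        shape=[-1, 1, -1], contiguity=[True, None, True], dtype=DataType.Float
    )
    # New syntax: only the dims to squeeze are recorded, not the input shape.
    t1 = fd.ops.squeeze(t0, dims=[1])
    fd.add_output(t1)

# Because no original_shape is baked in, the same definition can serve
# any runtime shape matching the [-1, 1, -1] pattern.
out = fd.execute([torch.randn(2, 1, 8, device="cuda")])
```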

This PR also includes:
* Adds input and error testing for the squeeze Python API (see the sketch below)
* Replaces the old syntax in `test_python_frontend.py`
* Bumps the version number to `0.1.5`
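
As a sketch of what the new error testing exercises (hypothetical shapes; the error wording comes from the `NVF_CHECK` added in `csrc/ops/alias.cpp`):

```python
import torch
from nvfuser import FusionDefinition, DataType

with FusionDefinition() as fd:
    t0 = fd.define_tensor(
        shape=[-1, 1], contiguity=[True, None], dtype=DataType.Float
    )
    # dims=[5] is out of range for a 2-D tensor and should be rejected by
    # the "Squeeze dim is outside of Tensor size!" check.
    t1 = fd.ops.squeeze(t0, dims=[5])
    fd.add_output(t1)

try:
    fd.execute([torch.randn(3, 1, device="cuda")])
except Exception as e:
    print(e)  # Squeeze dim is outside of Tensor size! ...
```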

---------

Co-authored-by: jjsjann123 <[email protected]>
kevinstephano and jjsjann123 authored Jan 10, 2024
1 parent ed9f9dd commit 61b0ab6
Showing 12 changed files with 167 additions and 159 deletions.
153 changes: 65 additions & 88 deletions csrc/ops/alias.cpp
```diff
@@ -205,26 +205,46 @@ TensorView* flatten(TensorView* x, int64_t start_dim, int64_t end_dim) {
   return out;
 }
 
-TensorView* squeeze(TensorView* x, const std::vector<bool>& to_squeeze) {
+TensorView* squeeze(TensorView* x, const std::vector<int64_t>& dims) {
   NVF_ERROR(x != nullptr, "Input is invalid.");
   auto x_dom = x->domain()->noReductions();
   const auto ndims = static_cast<int>(x_dom.size());
 
   NVF_ERROR(
-      ndims == (int)to_squeeze.size(),
-      "Invalid to_squeeze for squeeze: ",
-      to_squeeze,
-      ". Input tensor: ",
-      x->toString());
+      (int)dims.size() <= ndims,
+      "The dims to squeeze must be <= the number of dims of the input tensor. ",
+      "Squeeze dims: ",
+      dims.size(),
+      " Input Tensor dims: ",
+      ndims);
+
+  std::vector<bool> to_squeeze(ndims, false);
+  for (auto dim : dims) {
+    // Handle negative relative to the end dimensions specifications
+    if (dim < 0) {
+      dim = static_cast<int64_t>(to_squeeze.size()) + dim;
+    }
+    NVF_CHECK(
+        (dim >= 0) && (static_cast<size_t>(dim) < to_squeeze.size()),
+        "Squeeze dim is outside of Tensor size! Tensor Size: ",
+        to_squeeze.size(),
+        " Dim: ",
+        dim);
+    to_squeeze[dim] = true;
+  }
 
   std::vector<IterDomain*> out_domain;
   for (const auto idx : c10::irange(ndims)) {
     auto id = x_dom[idx];
     if (to_squeeze[idx]) {
       if (!id->isSymbolic()) {
-        NVF_CHECK(
-            id->isBroadcast(),
-            "Can not squeeze non-broadcasting dimension(s).");
+        // If a squeeze is attempted on a non-broadcast dimension
+        // just don't do it! This conforms with Pytorch.
+        if (!id->isBroadcast()) {
+          to_squeeze[idx] = false;
+          out_domain.push_back(id->cloneWithoutRFactor());
+          continue;
+        }
         NVF_CHECK(
             !id->hasExpandedExtent(), "Can not squeeze expanded dimension(s).");
         NVF_CHECK(
@@ -241,99 +261,56 @@ TensorView* squeeze(TensorView* x, const std::vector<bool>& to_squeeze) {
           out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)),
       *x->getDataType());
 
-  IrBuilder::create<SqueezeOp>(x->container(), out, x, to_squeeze);
+  std::vector<bool> all_false(to_squeeze.size(), false);
+  // If a squeeze does not perform a squeeze, create a no-op
+  if (to_squeeze == all_false) {
+    IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, out, x);
+  } else {
+    IrBuilder::create<SqueezeOp>(x->container(), out, x, to_squeeze);
+  }
 
   return out;
 }
 
-TensorView* squeeze(TensorView* x, const std::vector<int64_t>& sizes) {
-  NVF_ERROR(x != nullptr, "Input is invalid.");
-  const auto ndims = static_cast<int>(x->domain()->noReductions().size());
-
-  NVF_ERROR(
-      ndims == int(sizes.size()),
-      "Invalid sizes for squeeze: ",
-      sizes,
-      ". Input tensor: ",
-      x->toString());
-
-  std::vector<bool> to_squeeze(ndims);
-  for (const auto idx : c10::irange(sizes.size())) {
-    to_squeeze[idx] = (sizes[idx] == 1);
-  }
-  return squeeze(x, to_squeeze);
-}
-
-TensorView* squeeze(TensorView* x, const std::vector<int64_t>& sizes, int dim) {
-  NVF_ERROR(x != nullptr, "Input is invalid.");
-  const auto ndims = static_cast<int>(x->domain()->noReductions().size());
-
-  NVF_ERROR(
-      ndims == int(sizes.size()),
-      "Invalid sizes for squeeze: ",
-      sizes,
-      ". Input tensor: ",
-      x->toString());
-
-  if (dim < 0) {
-    dim = ndims + dim;
-  }
-
-  NVF_ERROR(
-      dim >= 0 && dim < ndims,
-      "Invalid position to squeeze: ",
-      dim,
-      ". Input tensor: ",
-      x->toString());
-
-  if (sizes[dim] == 1) {
-    std::vector<bool> to_squeeze(ndims, false);
-    to_squeeze[dim] = true;
-    return squeeze(x, to_squeeze);
-  } else {
-    return set(x);
-  }
-}
-
-TensorView* squeeze(
-    TensorView* x,
-    const std::vector<int64_t>& sizes,
-    const std::vector<int64_t>& dims) {
+TensorView* squeeze(TensorView* x, const std::vector<bool>& to_squeeze) {
   NVF_ERROR(x != nullptr, "Input is invalid.");
-  const auto ndims = static_cast<int>(x->domain()->noReductions().size());
+  auto x_dom = x->domain()->noReductions();
+  const auto ndims = static_cast<int>(x_dom.size());
 
   NVF_ERROR(
-      ndims == int(sizes.size()),
-      "Invalid sizes for squeeze: ",
-      sizes,
+      ndims == (int)to_squeeze.size(),
+      "Invalid to_squeeze for squeeze: ",
+      to_squeeze,
       ". Input tensor: ",
       x->toString());
 
-  bool is_all_singleton_dimensions = true;
-
-  std::vector<bool> to_squeeze(ndims);
-  for (auto dim : dims) {
-    if (dim < 0) {
-      dim = ndims + dim;
-    }
-
-    NVF_ERROR(
-        dim >= 0 && dim < ndims,
-        "Invalid position to squeeze: ",
-        dim,
-        ". Input tensor: ",
-        x->toString());
-
-    bool is_singleton_dim = (sizes[dim] == 1);
-    to_squeeze.at(dim) = is_singleton_dim;
-    is_all_singleton_dimensions &= is_singleton_dim;
-  }
-
-  if (is_all_singleton_dimensions) {
-    return squeeze(x, to_squeeze);
-  } else {
-    return set(x);
-  }
+  std::vector<IterDomain*> out_domain;
+  for (const auto idx : c10::irange(ndims)) {
+    auto id = x_dom[idx];
+    if (to_squeeze[idx]) {
+      if (!id->isSymbolic()) {
+        NVF_CHECK(
+            id->isBroadcast(),
+            "Can not squeeze non-broadcasting dimension(s).");
+        NVF_CHECK(
+            !id->hasExpandedExtent(), "Can not squeeze expanded dimension(s).");
+        NVF_CHECK(
+            id->extent()->isConstScalar() && id->extent()->evaluate() == 1,
+            "Can not squeeze dimension(s) with size != 1.");
+      }
+    } else {
+      out_domain.push_back(id->cloneWithoutRFactor());
+    }
+  }
+
+  auto out = IrBuilder::create<TensorView>(
+      IrBuilder::create<TensorDomain>(
+          out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)),
+      *x->getDataType());
+
+  IrBuilder::create<SqueezeOp>(x->container(), out, x, to_squeeze);
+
+  return out;
 }
 
 TensorView* unsqueeze(TensorView* x, int dim) {
```
14 changes: 5 additions & 9 deletions csrc/ops/alias.h
```diff
@@ -43,16 +43,12 @@ TensorView* reshape(TensorView* x, const std::vector<Val*>& new_sizes);
 
 TensorView* flatten(TensorView* x, int64_t start_dim = 0, int64_t end_dim = -1);
 
-TensorView* squeeze(TensorView* x, const std::vector<bool>& to_squeeze);
-
-TensorView* squeeze(TensorView* x, const std::vector<int64_t>& sizes);
-
-TensorView* squeeze(TensorView* x, const std::vector<int64_t>& sizes, int dim);
+// This implementation is specific to Pytorch where if you attempt to squeeze
+// a non-broadcast dimension, the squeeze does not do anything to that
+// dimension and does not trigger an error.
+TensorView* squeeze(TensorView* x, const std::vector<int64_t>& dims);
 
-TensorView* squeeze(
-    TensorView* x,
-    const std::vector<int64_t>& sizes,
-    const std::vector<int64_t>& dims);
+TensorView* squeeze(TensorView* x, const std::vector<bool>& to_squeeze);
 
 TensorView* unsqueeze(TensorView* x, int dim);
 
```
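The comment added to `alias.h` is worth a concrete illustration: squeezing a non-broadcast dimension silently leaves it alone rather than raising, matching PyTorch. A minimal Python-frontend sketch under the same assumptions as the earlier examples (`define_tensor` keywords may vary by version):

```python
import torch
from nvfuser import FusionDefinition, DataType

with FusionDefinition() as fd:
    # Static shape: dim 0 is a regular size-4 dim, dim 1 is a size-1 dim.
    t0 = fd.define_tensor(
        shape=[4, 1], contiguity=[True, None], dtype=DataType.Float
    )
    # Asking to squeeze both dims: dim 1 is squeezed; dim 0 is a no-op
    # rather than an error, per the PyTorch-conforming behavior above.
    t1 = fd.ops.squeeze(t0, dims=[0, 1])
    fd.add_output(t1)

out = fd.execute([torch.randn(4, 1, device="cuda")])
assert out[0].shape == (4,)  # only the size-1 dim was removed
```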
53 changes: 8 additions & 45 deletions csrc/python_frontend/fusion_record.h
```diff
@@ -582,84 +582,49 @@ struct SqueezeOpRecord : RecordFunctor {
   SqueezeOpRecord(
       std::vector<State> _args,
       std::vector<State> _outputs,
-      std::vector<int64_t> original_shape,
       std::vector<int64_t> dims)
       : RecordFunctor(
             std::move(_args),
             std::move(_outputs),
             "ops.squeeze",
             serde::RecordType::SqueezeOp),
-        original_shape_(std::move(original_shape)),
         dims_(std::move(dims)) {}
   ~SqueezeOpRecord() override = default;
   RecordFunctor* clone() final {
     return new SqueezeOpRecord(*this);
   }
 
   //! Child specific hash function in lower 32 bits.
-  //! | 31 -------------- 16 | 15 -------------- 0 |
-  //! | Squeeze Dim hash     | original_shape hash |
+  //! | 31 ------------------------------------- 0 |
+  //! | Squeeze Dim hash                            |
   size_t hash() const final {
     auto result = RecordFunctor::hash();
-    size_t original_shape_hash = 0;
-    for (auto shape : original_shape_) {
-      original_shape_hash ^= static_cast<size_t>(shape);
-    }
     size_t squeeze_dims_hash = 0;
     for (auto dim : dims_) {
       squeeze_dims_hash ^= static_cast<size_t>(dim);
     }
-    squeeze_dims_hash = (squeeze_dims_hash & 0xffff) << 16;
-    return result | squeeze_dims_hash | (original_shape_hash & 0xffff);
+    result = result | (squeeze_dims_hash & 0xffffffff);
+    return result;
   }
 
   bool operator==(const RecordFunctor& other) const final {
     auto result = false;
     if (auto child_ptr = dynamic_cast<const SqueezeOpRecord*>(&other)) {
-      result = RecordFunctor::operator==(other);
-      if (result) {
-        result = (original_shape_.size() == child_ptr->original_shape_.size());
-        if (result) {
-          for (size_t i = 0; i < dims_.size(); ++i) {
-            if (dims_[i] != child_ptr->dims_[i]) {
-              result = false;
-              break;
-            }
-          }
-        }
-        if (result) {
-          for (size_t i = 0; i < original_shape_.size(); ++i) {
-            if (original_shape_[i] != child_ptr->original_shape_[i]) {
-              result = false;
-              break;
-            }
-          }
-        }
-      }
+      result = RecordFunctor::operator==(other) && (dims_ == child_ptr->dims_);
     }
     return result;
   }
 
   void operator()(FusionState& fd) final {
     auto arg = fd.getFusionState(args_.at(0).index)->template as<TensorView>();
-    auto output = squeeze(arg, original_shape_, dims_);
+    auto output = squeeze(arg, dims_);
     fd.setFusionState(outputs_.at(0).index, output);
   }
 
   void print(std::ostream& os, bool close_function = true) const final {
     RecordFunctor::print(os, false);
-    os << ", original_shape=[";
+    os << ", dims=[";
     bool first_arg = true;
-    for (auto shape : original_shape_) {
-      if (first_arg) {
-        first_arg = false;
-      } else {
-        os << ", ";
-      }
-      os << shape;
-    }
-    os << "], dims=[";
-    first_arg = true;
     for (auto dim : dims_) {
       if (first_arg) {
         first_arg = false;
@@ -678,12 +643,10 @@ struct SqueezeOpRecord : RecordFunctor {
       flatbuffers::FlatBufferBuilder& builder) const final {
     return {
         serde::RecordData::Squeeze,
-        serde::CreateSqueezeDirect(builder, &original_shape_, &dims_).Union()};
+        serde::CreateSqueezeDirect(builder, &dims_).Union()};
   }
 
  private:
-  //! Represents the tensor dimensions of the input tensor.
-  std::vector<int64_t> original_shape_;
   //! Dimension to squeeze.
   std::vector<int64_t> dims_;
 };
```
3 changes: 0 additions & 3 deletions csrc/python_frontend/python_bindings.cpp
```diff
@@ -2580,7 +2580,6 @@ void initNvFuserPythonBindings(PyObject* module) {
       "squeeze",
       [](FusionDefinition::Operators& self,
          Tensor arg,
-         std::vector<int64_t>& original_shape,
          std::vector<int64_t>& dims) -> Tensor {
         FUSER_PERF_SCOPE("Operators.squeeze");
         NVF_CHECK(
@@ -2590,12 +2589,10 @@ void initNvFuserPythonBindings(PyObject* module) {
         fd->defineRecord(new SqueezeOpRecord(
             {fd->recordingState(arg())},
             {fd->recordingState(output())},
-            std::move(original_shape),
             std::move(dims)));
         return output;
       },
       py::arg("arg"),
-      py::arg("original_shape"),
       py::arg("dims"),
       py::return_value_policy::reference);
   nvf_ops.def(
```
1 change: 0 additions & 1 deletion csrc/serde/fusion_cache.fbs
```diff
@@ -302,7 +302,6 @@ table Slice {
 
 // Data for SqueezeOpRecord
 table Squeeze {
-  original_shape: [long];
   squeeze_dims: [long];
 }
 
```
1 change: 0 additions & 1 deletion csrc/serde/fusion_record.cpp
```diff
@@ -528,7 +528,6 @@ void RecordFunctorFactory::registerAllParsers() {
     return new python_frontend::SqueezeOpRecord(
         parseStateArgs(buffer->args()),
         parseStateArgs(buffer->outputs()),
-        parseVector(data->original_shape()),
         parseVector(data->squeeze_dims()));
   };
   registerParser(RecordType::SqueezeOp, deserializeSqueezeRecord);
```