diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index fd186232643..29d39a8521a 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -205,26 +205,46 @@ TensorView* flatten(TensorView* x, int64_t start_dim, int64_t end_dim) { return out; } -TensorView* squeeze(TensorView* x, const std::vector& to_squeeze) { +TensorView* squeeze(TensorView* x, const std::vector& dims) { NVF_ERROR(x != nullptr, "Input is invalid."); auto x_dom = x->domain()->noReductions(); const auto ndims = static_cast(x_dom.size()); NVF_ERROR( - ndims == (int)to_squeeze.size(), - "Invalid to_squeeze for squeeze: ", - to_squeeze, - ". Input tensor: ", - x->toString()); + (int)dims.size() <= ndims, + "The dims to squeeze must be <= the number of dims of the input tensor. ", + "Squeeze dims: ", + dims.size(), + " Input Tensor dims: ", + ndims); + + std::vector to_squeeze(ndims, false); + for (auto dim : dims) { + // Handle negative relative to the end dimensions specifications + if (dim < 0) { + dim = static_cast(to_squeeze.size()) + dim; + } + NVF_CHECK( + (dim >= 0) && (static_cast(dim) < to_squeeze.size()), + "Squeeze dim is outside of Tensor size! Tensor Size: ", + to_squeeze.size(), + " Dim: ", + dim); + to_squeeze[dim] = true; + } std::vector out_domain; for (const auto idx : c10::irange(ndims)) { auto id = x_dom[idx]; if (to_squeeze[idx]) { if (!id->isSymbolic()) { - NVF_CHECK( - id->isBroadcast(), - "Can not squeeze non-broadcasting dimension(s)."); + // If a squeeze is attempted on a non-broadcast dimension + // just don't do it! This conforms with Pytorch. + if (!id->isBroadcast()) { + to_squeeze[idx] = false; + out_domain.push_back(id->cloneWithoutRFactor()); + continue; + } NVF_CHECK( !id->hasExpandedExtent(), "Can not squeeze expanded dimension(s)."); NVF_CHECK( @@ -241,99 +261,56 @@ TensorView* squeeze(TensorView* x, const std::vector& to_squeeze) { out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), *x->getDataType()); - IrBuilder::create(x->container(), out, x, to_squeeze); - - return out; -} - -TensorView* squeeze(TensorView* x, const std::vector& sizes) { - NVF_ERROR(x != nullptr, "Input is invalid."); - const auto ndims = static_cast(x->domain()->noReductions().size()); - - NVF_ERROR( - ndims == int(sizes.size()), - "Invalid sizes for squeeze: ", - sizes, - ". Input tensor: ", - x->toString()); - - std::vector to_squeeze(ndims); - for (const auto idx : c10::irange(sizes.size())) { - to_squeeze[idx] = (sizes[idx] == 1); - } - return squeeze(x, to_squeeze); -} - -TensorView* squeeze(TensorView* x, const std::vector& sizes, int dim) { - NVF_ERROR(x != nullptr, "Input is invalid."); - const auto ndims = static_cast(x->domain()->noReductions().size()); - - NVF_ERROR( - ndims == int(sizes.size()), - "Invalid sizes for squeeze: ", - sizes, - ". Input tensor: ", - x->toString()); - - if (dim < 0) { - dim = ndims + dim; - } - - NVF_ERROR( - dim >= 0 && dim < ndims, - "Invalid position to squeeze: ", - dim, - ". Input tensor: ", - x->toString()); - - if (sizes[dim] == 1) { - std::vector to_squeeze(ndims, false); - to_squeeze[dim] = true; - return squeeze(x, to_squeeze); + std::vector all_false(to_squeeze.size(), false); + // If a squeeze does not perform a squeeze, create a no-op + if (to_squeeze == all_false) { + IrBuilder::create(LoadStoreOpType::Set, out, x); } else { - return set(x); + IrBuilder::create(x->container(), out, x, to_squeeze); } + + return out; } -TensorView* squeeze( - TensorView* x, - const std::vector& sizes, - const std::vector& dims) { +TensorView* squeeze(TensorView* x, const std::vector& to_squeeze) { NVF_ERROR(x != nullptr, "Input is invalid."); - const auto ndims = static_cast(x->domain()->noReductions().size()); + auto x_dom = x->domain()->noReductions(); + const auto ndims = static_cast(x_dom.size()); NVF_ERROR( - ndims == int(sizes.size()), - "Invalid sizes for squeeze: ", - sizes, + ndims == (int)to_squeeze.size(), + "Invalid to_squeeze for squeeze: ", + to_squeeze, ". Input tensor: ", x->toString()); - bool is_all_singleton_dimensions = true; - - std::vector to_squeeze(ndims); - for (auto dim : dims) { - if (dim < 0) { - dim = ndims + dim; + std::vector out_domain; + for (const auto idx : c10::irange(ndims)) { + auto id = x_dom[idx]; + if (to_squeeze[idx]) { + if (!id->isSymbolic()) { + NVF_CHECK( + id->isBroadcast(), + "Can not squeeze non-broadcasting dimension(s)."); + NVF_CHECK( + !id->hasExpandedExtent(), "Can not squeeze expanded dimension(s)."); + NVF_CHECK( + id->extent()->isConstScalar() && id->extent()->evaluate() == 1, + "Can not squeeze dimension(s) with size != 1."); + } + } else { + out_domain.push_back(id->cloneWithoutRFactor()); } + } - NVF_ERROR( - dim >= 0 && dim < ndims, - "Invalid position to squeeze: ", - dim, - ". Input tensor: ", - x->toString()); + auto out = IrBuilder::create( + IrBuilder::create( + out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), + *x->getDataType()); - bool is_singleton_dim = (sizes[dim] == 1); - to_squeeze.at(dim) = is_singleton_dim; - is_all_singleton_dimensions &= is_singleton_dim; - } + IrBuilder::create(x->container(), out, x, to_squeeze); - if (is_all_singleton_dimensions) { - return squeeze(x, to_squeeze); - } else { - return set(x); - } + return out; } TensorView* unsqueeze(TensorView* x, int dim) { diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h index 1f8e078f8f6..01e7ca16e74 100644 --- a/csrc/ops/alias.h +++ b/csrc/ops/alias.h @@ -43,16 +43,12 @@ TensorView* reshape(TensorView* x, const std::vector& new_sizes); TensorView* flatten(TensorView* x, int64_t start_dim = 0, int64_t end_dim = -1); -TensorView* squeeze(TensorView* x, const std::vector& to_squeeze); - -TensorView* squeeze(TensorView* x, const std::vector& sizes); - -TensorView* squeeze(TensorView* x, const std::vector& sizes, int dim); +// This implementation is specific to Pytorch where if you attempt to squeeze +// a non-broadcast dimension, the squeeze does not do anything to that +// dimension and does not trigger an error. +TensorView* squeeze(TensorView* x, const std::vector& dims); -TensorView* squeeze( - TensorView* x, - const std::vector& sizes, - const std::vector& dims); +TensorView* squeeze(TensorView* x, const std::vector& to_squeeze); TensorView* unsqueeze(TensorView* x, int dim); diff --git a/csrc/python_frontend/fusion_record.h b/csrc/python_frontend/fusion_record.h index 6926d57ce76..683ff8aaa48 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/csrc/python_frontend/fusion_record.h @@ -582,14 +582,12 @@ struct SqueezeOpRecord : RecordFunctor { SqueezeOpRecord( std::vector _args, std::vector _outputs, - std::vector original_shape, std::vector dims) : RecordFunctor( std::move(_args), std::move(_outputs), "ops.squeeze", serde::RecordType::SqueezeOp), - original_shape_(std::move(original_shape)), dims_(std::move(dims)) {} ~SqueezeOpRecord() override = default; RecordFunctor* clone() final { @@ -597,69 +595,36 @@ struct SqueezeOpRecord : RecordFunctor { } //! Child specific hash function in lower 32 bits. - //! | 31 -------------- 16 | 15 -------------- 0 | - //! | Squeeze Dim hash | original_shape hash | + //! | 31 ------------------------------------- 0 | + //! | Squeeze Dim hash | size_t hash() const final { auto result = RecordFunctor::hash(); - size_t original_shape_hash = 0; - for (auto shape : original_shape_) { - original_shape_hash ^= static_cast(shape); - } size_t squeeze_dims_hash = 0; for (auto dim : dims_) { squeeze_dims_hash ^= static_cast(dim); } - squeeze_dims_hash = (squeeze_dims_hash & 0xffff) << 16; - return result | squeeze_dims_hash | (original_shape_hash & 0xffff); + result = result | (squeeze_dims_hash & 0xffffffff); + return result; } bool operator==(const RecordFunctor& other) const final { auto result = false; if (auto child_ptr = dynamic_cast(&other)) { - result = RecordFunctor::operator==(other); - if (result) { - result = (original_shape_.size() == child_ptr->original_shape_.size()); - if (result) { - for (size_t i = 0; i < dims_.size(); ++i) { - if (dims_[i] != child_ptr->dims_[i]) { - result = false; - break; - } - } - } - if (result) { - for (size_t i = 0; i < original_shape_.size(); ++i) { - if (original_shape_[i] != child_ptr->original_shape_[i]) { - result = false; - break; - } - } - } - } + result = RecordFunctor::operator==(other) && (dims_ == child_ptr->dims_); } return result; } void operator()(FusionState& fd) final { auto arg = fd.getFusionState(args_.at(0).index)->template as(); - auto output = squeeze(arg, original_shape_, dims_); + auto output = squeeze(arg, dims_); fd.setFusionState(outputs_.at(0).index, output); } void print(std::ostream& os, bool close_function = true) const final { RecordFunctor::print(os, false); - os << ", original_shape=["; + os << ", dims=["; bool first_arg = true; - for (auto shape : original_shape_) { - if (first_arg) { - first_arg = false; - } else { - os << ", "; - } - os << shape; - } - os << "], dims=["; - first_arg = true; for (auto dim : dims_) { if (first_arg) { first_arg = false; @@ -678,12 +643,10 @@ struct SqueezeOpRecord : RecordFunctor { flatbuffers::FlatBufferBuilder& builder) const final { return { serde::RecordData::Squeeze, - serde::CreateSqueezeDirect(builder, &original_shape_, &dims_).Union()}; + serde::CreateSqueezeDirect(builder, &dims_).Union()}; } private: - //! Represents the tensor dimensions of the input tensor. - std::vector original_shape_; //! Dimension to squeeze. std::vector dims_; }; diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 6bad1a7894f..98e9ed7bb1f 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -2580,7 +2580,6 @@ void initNvFuserPythonBindings(PyObject* module) { "squeeze", [](FusionDefinition::Operators& self, Tensor arg, - std::vector& original_shape, std::vector& dims) -> Tensor { FUSER_PERF_SCOPE("Operators.squeeze"); NVF_CHECK( @@ -2590,12 +2589,10 @@ void initNvFuserPythonBindings(PyObject* module) { fd->defineRecord(new SqueezeOpRecord( {fd->recordingState(arg())}, {fd->recordingState(output())}, - std::move(original_shape), std::move(dims))); return output; }, py::arg("arg"), - py::arg("original_shape"), py::arg("dims"), py::return_value_policy::reference); nvf_ops.def( diff --git a/csrc/serde/fusion_cache.fbs b/csrc/serde/fusion_cache.fbs index f8106ec0b5e..de19a8faf62 100644 --- a/csrc/serde/fusion_cache.fbs +++ b/csrc/serde/fusion_cache.fbs @@ -302,7 +302,6 @@ table Slice { // Data for SqueezeOpRecord table Squeeze { - original_shape: [long]; squeeze_dims: [long]; } diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp index b09ffcc88b2..c8d92e089f0 100644 --- a/csrc/serde/fusion_record.cpp +++ b/csrc/serde/fusion_record.cpp @@ -528,7 +528,6 @@ void RecordFunctorFactory::registerAllParsers() { return new python_frontend::SqueezeOpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), - parseVector(data->original_shape()), parseVector(data->squeeze_dims())); }; registerParser(RecordType::SqueezeOp, deserializeSqueezeRecord); diff --git a/python_tests/pytest_input_generators.py b/python_tests/pytest_input_generators.py index 5c3038712ab..ed544f2f63f 100644 --- a/python_tests/pytest_input_generators.py +++ b/python_tests/pytest_input_generators.py @@ -1233,6 +1233,74 @@ def slice_error_generator( yield SampleInput(input_tensor, **es.kwargs), es.ex_type, es.ex_str +def squeeze_generator( + op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs +): + make_arg = partial( + make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad + ) + + # shape, squeeze_dims + cases = ( + ((5, 1, 1), (1, 2)), + ((5, 1, 1), (-2, -1)), + ((5, 1, 1), (2, 1)), + ((5, 1, 1), (-1, -2)), + ((1, 5, 1), (0, 2)), + ((1, 5, 1), (-3, -1)), + ((1, 1, 5), (0, 1)), + ((1, 1, 5), (-3, -2)), + ((5, 5, 5), ()), + ((1, 1, 1), ()), + ((1, 1, 1), (0, 1, 2)), + ((1, 1, 1), (-3, -2, -1)), + # No-op test cases + ((5, 5, 5), (0, 1, 2)), + ((5, 5, 5), (-3, -2, -1)), + ((), ()), + ) + + for shape, squeeze_dims in cases: + a = make_arg(shape) + yield SampleInput(a, squeeze_dims) + + +def squeeze_error_generator( + op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs +): + make_arg = partial( + make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad + ) + + # shape, start_indices, end_indices + out_of_range_cases = ( + ((5, 1, 1), (-4, -5)), # Dims are completely outside of tensor dims + ((5, 1, 1), (3, 4)), + ((5, 1, 1), (-3, -4)), # One dim in range, one dim out of range + ((5, 1, 1), (2, 3)), + ) + + error_type = RuntimeError + error_str = "Squeeze dim is outside of Tensor size!" + for shape, squeeze_dims in out_of_range_cases: + a = make_arg(shape) + yield SampleInput(a, squeeze_dims), error_type, error_str + + # shape, start_indices, end_indices + too_many_indices_cases = ( + ((5, 1, 1), (1, 2, 3, 4)), + ((5, 1, 1), (-1, -2, -3, -4)), + ((), (0,)), + ((), (-1,)), + ) + + error_type = RuntimeError + error_str = "The dims to squeeze must be <= the number of dims of the input tensor" + for shape, squeeze_dims in too_many_indices_cases: + a = make_arg(shape) + yield SampleInput(a, squeeze_dims), error_type, error_str + + def take_along_axis_generator( op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs ): diff --git a/python_tests/pytest_opinfos.py b/python_tests/pytest_opinfos.py index 6b5c9cce804..53d6f0f80b3 100644 --- a/python_tests/pytest_opinfos.py +++ b/python_tests/pytest_opinfos.py @@ -41,6 +41,8 @@ reshape_error_generator, slice_generator, slice_error_generator, + squeeze_generator, + squeeze_error_generator, take_along_axis_generator, take_along_axis_error_generator, tensor_size_error_generator, @@ -1020,6 +1022,19 @@ def torch_reshape_sym_fn(input_tensor, output_shaped_tensor): ) shape_ops.append(slice_opinfo) +squeeze_opinfo = OpInfo( + lambda fd: fd.ops.squeeze, + "squeeze", + sample_input_generator=squeeze_generator, + error_input_generator=squeeze_error_generator, + reference=torch.squeeze, + symbolic_parameter_list=( + ArgumentType.Symbolic, + ArgumentType.Constant, + ), +) +shape_ops.append(squeeze_opinfo) + take_along_axis_opinfo = OpInfo( lambda fd: fd.ops.take_along_axis, "take_along_dim", diff --git a/python_tests/test_python_frontend.py b/python_tests/test_python_frontend.py index 9383a4c4746..831529ddd63 100644 --- a/python_tests/test_python_frontend.py +++ b/python_tests/test_python_frontend.py @@ -851,14 +851,8 @@ def fusion_func(fd: FusionDefinition): t0 = fd.define_tensor(shape=[-1], contiguity=[True]) t1 = fd.define_tensor(sizes=t1_sizes, strides=[4, 1, 1]) t2 = fd.define_tensor(sizes=t2_sizes, strides=[4, 4, 1]) - t3 = fd.ops.squeeze(t1, t1_sizes, [0, -1]) - t4 = fd.ops.squeeze( - t2, - t2_sizes, - [ - -2, - ], - ) + t3 = fd.ops.squeeze(t1, [0, -1]) + t4 = fd.ops.squeeze(t2, [-2]) t5 = fd.ops.sum(t4, [0]) t6 = fd.ops.mul(t0, t3) t7 = fd.ops.mul(t6, t5) diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index bcff32895c6..72c5246a550 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -2547,7 +2547,7 @@ TEST_F(NVFuserTest, FusionSqueeze1_CUDA) { // [I, B] auto tv1 = sum(tv0, {1}, true); // [I] - auto tv2 = squeeze(tv1, std::vector{shape[0], 1}); + auto tv2 = squeeze(tv1, std::vector{1}); fusion.addOutput(tv2); NVF_CHECK(tv2->nDims() == 1, "Unexpected squeeze result: ", tv2->toString()); @@ -2557,7 +2557,7 @@ TEST_F(NVFuserTest, FusionSqueeze1_CUDA) { // tv3 has only one non-reduction axis. The extent of the first axis // is not one, so squeeze should fail. // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) - ASSERT_ANY_THROW(squeeze(tv3, std::vector{shape[0], 1})); + ASSERT_ANY_THROW(squeeze(tv3, std::vector{1})); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({10, 11}, options); diff --git a/test/test_resize.cpp b/test/test_resize.cpp index b844e832a88..d60e65592c3 100644 --- a/test/test_resize.cpp +++ b/test/test_resize.cpp @@ -2127,7 +2127,7 @@ TEST_F(ResizeTest, FusionSqueezeSymbolic) { // concretized to Broadcast // NOTE: squeeze interface should be updated to match reshape and friends, // accepting Val inputs - auto tv2 = squeeze(tv1, {20, 1}, 1); + auto tv2 = squeeze(tv1, std::vector{1}); // tv1 is of shape {0, 5} fusion->addOutput(tv2); diff --git a/version.txt b/version.txt index 845639eef26..9faa1b7a733 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.1.4 +0.1.5