diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index 53457f02b8809..ff53a3c7a6522 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -263,6 +263,12 @@ constexpr bool empty(const T &RangeOrContainer) { return adl_begin(RangeOrContainer) == adl_end(RangeOrContainer); } +/// Returns true of the given range only contains a single element. +template bool hasSingleElement(ContainerTy &&c) { + auto it = std::begin(c), e = std::end(c); + return it != e && std::next(it) == e; +} + /// Return a range covering \p RangeOrContainer with the first N elements /// excluded. template auto drop_begin(T &&RangeOrContainer, size_t N) { @@ -1017,6 +1023,213 @@ detail::concat_range concat(RangeTs &&... Ranges) { std::forward(Ranges)...); } +/// A utility class used to implement an iterator that contains some base object +/// and an index. The iterator moves the index but keeps the base constant. +template +class indexed_accessor_iterator + : public llvm::iterator_facade_base { +public: + ptrdiff_t operator-(const indexed_accessor_iterator &rhs) const { + assert(base == rhs.base && "incompatible iterators"); + return index - rhs.index; + } + bool operator==(const indexed_accessor_iterator &rhs) const { + return base == rhs.base && index == rhs.index; + } + bool operator<(const indexed_accessor_iterator &rhs) const { + assert(base == rhs.base && "incompatible iterators"); + return index < rhs.index; + } + + DerivedT &operator+=(ptrdiff_t offset) { + this->index += offset; + return static_cast(*this); + } + DerivedT &operator-=(ptrdiff_t offset) { + this->index -= offset; + return static_cast(*this); + } + + /// Returns the current index of the iterator. + ptrdiff_t getIndex() const { return index; } + + /// Returns the current base of the iterator. + const BaseT &getBase() const { return base; } + +protected: + indexed_accessor_iterator(BaseT base, ptrdiff_t index) + : base(base), index(index) {} + BaseT base; + ptrdiff_t index; +}; + +namespace detail { +/// The class represents the base of a range of indexed_accessor_iterators. It +/// provides support for many different range functionalities, e.g. +/// drop_front/slice/etc.. Derived range classes must implement the following +/// static methods: +/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index) +/// - Dereference an iterator pointing to the base object at the given +/// index. +/// * BaseT offset_base(const BaseT &base, ptrdiff_t index) +/// - Return a new base that is offset from the provide base by 'index' +/// elements. +template +class indexed_accessor_range_base { +public: + using RangeBaseT = + indexed_accessor_range_base; + + /// An iterator element of this range. + class iterator : public indexed_accessor_iterator { + public: + // Index into this iterator, invoking a static method on the derived type. + ReferenceT operator*() const { + return DerivedT::dereference_iterator(this->getBase(), this->getIndex()); + } + + private: + iterator(BaseT owner, ptrdiff_t curIndex) + : indexed_accessor_iterator( + owner, curIndex) {} + + /// Allow access to the constructor. + friend indexed_accessor_range_base; + }; + + indexed_accessor_range_base(iterator begin, iterator end) + : base(DerivedT::offset_base(begin.getBase(), begin.getIndex())), + count(end.getIndex() - begin.getIndex()) {} + indexed_accessor_range_base(const iterator_range &range) + : indexed_accessor_range_base(range.begin(), range.end()) {} + indexed_accessor_range_base(BaseT base, ptrdiff_t count) + : base(base), count(count) {} + + iterator begin() const { return iterator(base, 0); } + iterator end() const { return iterator(base, count); } + ReferenceT operator[](unsigned index) const { + assert(index < size() && "invalid index for value range"); + return DerivedT::dereference_iterator(base, index); + } + + /// Compare this range with another. + template bool operator==(const OtherT &other) { + return size() == std::distance(other.begin(), other.end()) && + std::equal(begin(), end(), other.begin()); + } + + /// Return the size of this range. + size_t size() const { return count; } + + /// Return if the range is empty. + bool empty() const { return size() == 0; } + + /// Drop the first N elements, and keep M elements. + DerivedT slice(size_t n, size_t m) const { + assert(n + m <= size() && "invalid size specifiers"); + return DerivedT(DerivedT::offset_base(base, n), m); + } + + /// Drop the first n elements. + DerivedT drop_front(size_t n = 1) const { + assert(size() >= n && "Dropping more elements than exist"); + return slice(n, size() - n); + } + /// Drop the last n elements. + DerivedT drop_back(size_t n = 1) const { + assert(size() >= n && "Dropping more elements than exist"); + return DerivedT(base, size() - n); + } + + /// Take the first n elements. + DerivedT take_front(size_t n = 1) const { + return n < size() ? drop_back(size() - n) + : static_cast(*this); + } + + /// Take the last n elements. + DerivedT take_back(size_t n = 1) const { + return n < size() ? drop_front(size() - n) + : static_cast(*this); + } + + /// Allow conversion to any type accepting an iterator_range. + template >::value>> + operator RangeT() const { + return RangeT(iterator_range(*this)); + } + +protected: + indexed_accessor_range_base(const indexed_accessor_range_base &) = default; + indexed_accessor_range_base(indexed_accessor_range_base &&) = default; + indexed_accessor_range_base & + operator=(const indexed_accessor_range_base &) = default; + + /// The base that owns the provided range of values. + BaseT base; + /// The size from the owning range. + ptrdiff_t count; +}; +} // end namespace detail + +/// This class provides an implementation of a range of +/// indexed_accessor_iterators where the base is not indexable. Ranges with +/// bases that are offsetable should derive from indexed_accessor_range_base +/// instead. Derived range classes are expected to implement the following +/// static method: +/// * ReferenceT dereference(const BaseT &base, ptrdiff_t index) +/// - Dereference an iterator pointing to a parent base at the given index. +template +class indexed_accessor_range + : public detail::indexed_accessor_range_base< + DerivedT, std::pair, T, PointerT, ReferenceT> { +public: + indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count) + : detail::indexed_accessor_range_base< + DerivedT, std::pair, T, PointerT, ReferenceT>( + std::make_pair(base, startIndex), count) {} + using detail::indexed_accessor_range_base< + DerivedT, std::pair, T, PointerT, + ReferenceT>::indexed_accessor_range_base; + + /// Returns the current base of the range. + const BaseT &getBase() const { return this->base.first; } + + /// Returns the current start index of the range. + ptrdiff_t getStartIndex() const { return this->base.second; } + + /// See `detail::indexed_accessor_range_base` for details. + static std::pair + offset_base(const std::pair &base, ptrdiff_t index) { + // We encode the internal base as a pair of the derived base and a start + // index into the derived base. + return std::make_pair(base.first, base.second + index); + } + /// See `detail::indexed_accessor_range_base` for details. + static ReferenceT + dereference_iterator(const std::pair &base, + ptrdiff_t index) { + return DerivedT::dereference(base.first, base.second + index); + } +}; + +/// Given a container of pairs, return a range over the second elements. +template auto make_second_range(ContainerTy &&c) { + return llvm::map_range( + std::forward(c), + [](decltype((*std::begin(c))) elt) -> decltype((elt.second)) { + return elt.second; + }); +} + //===----------------------------------------------------------------------===// // Extra additions to //===----------------------------------------------------------------------===// diff --git a/llvm/unittests/Support/CMakeLists.txt b/llvm/unittests/Support/CMakeLists.txt index 0c32113338365..b9eeba165c96b 100644 --- a/llvm/unittests/Support/CMakeLists.txt +++ b/llvm/unittests/Support/CMakeLists.txt @@ -40,6 +40,7 @@ add_llvm_unittest(SupportTests FormatVariadicTest.cpp GlobPatternTest.cpp Host.cpp + IndexedAccessorTest.cpp ItaniumManglingCanonicalizerTest.cpp JSONTest.cpp KnownBitsTest.cpp diff --git a/mlir/unittests/Support/IndexedAccessorTest.cpp b/llvm/unittests/Support/IndexedAccessorTest.cpp similarity index 92% rename from mlir/unittests/Support/IndexedAccessorTest.cpp rename to llvm/unittests/Support/IndexedAccessorTest.cpp index dc08270771fc9..9981e91df100e 100644 --- a/mlir/unittests/Support/IndexedAccessorTest.cpp +++ b/llvm/unittests/Support/IndexedAccessorTest.cpp @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Support/STLExtras.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "gmock/gmock.h" -using namespace mlir; -using namespace mlir::detail; +using namespace llvm; +using namespace llvm::detail; namespace { /// Simple indexed accessor range that wraps an array. @@ -24,7 +24,7 @@ struct ArrayIndexedAccessorRange using indexed_accessor_range, T *, T>::indexed_accessor_range; - /// See `indexed_accessor_range` for details. + /// See `llvm::indexed_accessor_range` for details. static T &dereference(T *data, ptrdiff_t index) { return data[index]; } }; } // end anonymous namespace diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h index 55911cf10b7bf..3b5a82d239b95 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -290,16 +290,16 @@ class StructType : public Type::TypeBase { private: using RangeBaseT::RangeBaseT; - /// See `mlir::detail::indexed_accessor_range_base` for details. + /// See `llvm::detail::indexed_accessor_range_base` for details. static const Type *offset_base(const Type *object, ptrdiff_t index) { return object + index; } - /// See `mlir::detail::indexed_accessor_range_base` for details. + /// See `llvm::detail::indexed_accessor_range_base` for details. static Type dereference_iterator(const Type *object, ptrdiff_t index) { return object[index]; } diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 51b6edf8dddab..ab8cc6ee0a00a 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -648,13 +648,14 @@ using DenseIterPtrAndSplat = template class DenseElementIndexedIteratorImpl - : public indexed_accessor_iterator { + : public llvm::indexed_accessor_iterator { protected: DenseElementIndexedIteratorImpl(const char *data, bool isSplat, size_t dataIndex) - : indexed_accessor_iterator({data, isSplat}, dataIndex) {} + : llvm::indexed_accessor_iterator({data, isSplat}, + dataIndex) {} /// Return the current index for this iterator, adjusted for the case of a /// splat. @@ -746,8 +747,9 @@ class DenseElementsAttr /// A utility iterator that allows walking over the internal Attribute values /// of a DenseElementsAttr. class AttributeElementIterator - : public indexed_accessor_iterator { + : public llvm::indexed_accessor_iterator { public: /// Accesses the Attribute value at this iterator position. Attribute operator*() const; diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h index 8a505959e352e..3c246749c584c 100644 --- a/mlir/include/mlir/IR/BlockSupport.h +++ b/mlir/include/mlir/IR/BlockSupport.h @@ -54,19 +54,19 @@ class PredecessorIterator final /// This class implements the successor iterators for Block. class SuccessorRange final - : public detail::indexed_accessor_range_base { + : public llvm::detail::indexed_accessor_range_base< + SuccessorRange, BlockOperand *, Block *, Block *, Block *> { public: using RangeBaseT::RangeBaseT; SuccessorRange(Block *block); SuccessorRange(Operation *term); private: - /// See `detail::indexed_accessor_range_base` for details. + /// See `llvm::detail::indexed_accessor_range_base` for details. static BlockOperand *offset_base(BlockOperand *object, ptrdiff_t index) { return object + index; } - /// See `detail::indexed_accessor_range_base` for details. + /// See `llvm::detail::indexed_accessor_range_base` for details. static Block *dereference_iterator(BlockOperand *object, ptrdiff_t index) { return object[index].get(); } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index cad162cbb3f80..fbbd05e7b49de 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -113,7 +113,7 @@ class OpAsmPrinter { void printArrowTypeList(TypeRange &&types) { auto &os = getStream() << " -> "; - bool wrapped = !has_single_element(types) || + bool wrapped = !llvm::hasSingleElement(types) || (*types.begin()).template isa(); if (wrapped) os << '('; diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 6efdff2fb5a08..35c5b017ab89c 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -558,7 +558,7 @@ class OpPrintingFlags { /// suitable for a more derived type (e.g. ArrayRef) or a template range /// parameter. class TypeRange - : public detail::indexed_accessor_range_base< + : public llvm::detail::indexed_accessor_range_base< TypeRange, llvm::PointerUnion, Type, Type, Type> { @@ -589,9 +589,9 @@ class TypeRange /// * A pointer to the first element of an array of operands. using OwnerT = llvm::PointerUnion; - /// See `detail::indexed_accessor_range_base` for details. + /// See `llvm::detail::indexed_accessor_range_base` for details. static OwnerT offset_base(OwnerT object, ptrdiff_t index); - /// See `detail::indexed_accessor_range_base` for details. + /// See `llvm::detail::indexed_accessor_range_base` for details. static Type dereference_iterator(OwnerT object, ptrdiff_t index); /// Allow access to `offset_base` and `dereference_iterator`. @@ -640,9 +640,8 @@ inline bool operator==(ArrayRef lhs, const ValueTypeRange &rhs) { // OperandRange /// This class implements the operand iterators for the Operation class. -class OperandRange final - : public detail::indexed_accessor_range_base { +class OperandRange final : public llvm::detail::indexed_accessor_range_base< + OperandRange, OpOperand *, Value, Value, Value> { public: using RangeBaseT::RangeBaseT; OperandRange(Operation *op); @@ -658,11 +657,11 @@ class OperandRange final unsigned getBeginOperandIndex() const; private: - /// See `detail::indexed_accessor_range_base` for details. + /// See `llvm::detail::indexed_accessor_range_base` for details. static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) { return object + index; } - /// See `detail::indexed_accessor_range_base` for details. + /// See `llvm::detail::indexed_accessor_range_base` for details. static Value dereference_iterator(OpOperand *object, ptrdiff_t index) { return object[index].get(); } @@ -676,8 +675,8 @@ class OperandRange final /// This class implements the result iterators for the Operation class. class ResultRange final - : public indexed_accessor_range { + : public llvm::indexed_accessor_range { public: using indexed_accessor_range::indexed_accessor_range; @@ -690,12 +689,12 @@ class ResultRange final auto getType() const { return getTypes(); } private: - /// See `indexed_accessor_range` for details. + /// See `llvm::indexed_accessor_range` for details. static OpResult dereference(Operation *op, ptrdiff_t index); /// Allow access to `dereference_iterator`. - friend indexed_accessor_range; + friend llvm::indexed_accessor_range; }; //===----------------------------------------------------------------------===// @@ -730,7 +729,7 @@ struct ValueRangeOwner { /// suitable for a more derived type (e.g. ArrayRef) or a template range /// parameter. class ValueRange final - : public detail::indexed_accessor_range_base< + : public llvm::detail::indexed_accessor_range_base< ValueRange, detail::ValueRangeOwner, Value, Value, Value> { public: using RangeBaseT::RangeBaseT; @@ -762,9 +761,9 @@ class ValueRange final private: using OwnerT = detail::ValueRangeOwner; - /// See `detail::indexed_accessor_range_base` for details. + /// See `llvm::detail::indexed_accessor_range_base` for details. static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index); - /// See `detail::indexed_accessor_range_base` for details. + /// See `llvm::detail::indexed_accessor_range_base` for details. static Value dereference_iterator(const OwnerT &owner, ptrdiff_t index); /// Allow access to `offset_base` and `dereference_iterator`. diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h index 0ea9df4eca8fe..482c9c03b70a9 100644 --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -157,7 +157,7 @@ class Region { /// suitable for a more derived type (e.g. ArrayRef) or a template range /// parameter. class RegionRange - : public detail::indexed_accessor_range_base< + : public llvm::detail::indexed_accessor_range_base< RegionRange, PointerUnion *>, Region *, Region *, Region *> { /// The type representing the owner of this range. This is either a list of @@ -178,9 +178,9 @@ class RegionRange RegionRange(ArrayRef> regions); private: - /// See `detail::indexed_accessor_range_base` for details. + /// See `llvm::detail::indexed_accessor_range_base` for details. static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index); - /// See `detail::indexed_accessor_range_base` for details. + /// See `llvm::detail::indexed_accessor_range_base` for details. static Region *dereference_iterator(const OwnerT &owner, ptrdiff_t index); /// Allow access to `offset_base` and `dereference_iterator`. diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 2375f6adc81dd..c07c6c5e40a5a 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -148,7 +148,7 @@ class IRMultiObjectWithUseList : public IRObjectWithUseList { return {use_begin(value), use_end(value)}; } bool hasOneUse(ValueType value) const { - return mlir::has_single_element(getUses(value)); + return llvm::hasSingleElement(getUses(value)); } bool use_empty(ValueType value) const { return use_begin(value) == use_end(value); diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h index 2279a35dc7a8e..88636f6481b9e 100644 --- a/mlir/include/mlir/Support/STLExtras.h +++ b/mlir/include/mlir/Support/STLExtras.h @@ -88,243 +88,6 @@ inline void interleaveComma(const Container &c, raw_ostream &os) { interleaveComma(c, os, [&](const T &a) { os << a; }); } -//===----------------------------------------------------------------------===// -// Extra additions to -//===----------------------------------------------------------------------===// - -/// A utility class used to implement an iterator that contains some base object -/// and an index. The iterator moves the index but keeps the base constant. -template -class indexed_accessor_iterator - : public llvm::iterator_facade_base { -public: - ptrdiff_t operator-(const indexed_accessor_iterator &rhs) const { - assert(base == rhs.base && "incompatible iterators"); - return index - rhs.index; - } - bool operator==(const indexed_accessor_iterator &rhs) const { - return base == rhs.base && index == rhs.index; - } - bool operator<(const indexed_accessor_iterator &rhs) const { - assert(base == rhs.base && "incompatible iterators"); - return index < rhs.index; - } - - DerivedT &operator+=(ptrdiff_t offset) { - this->index += offset; - return static_cast(*this); - } - DerivedT &operator-=(ptrdiff_t offset) { - this->index -= offset; - return static_cast(*this); - } - - /// Returns the current index of the iterator. - ptrdiff_t getIndex() const { return index; } - - /// Returns the current base of the iterator. - const BaseT &getBase() const { return base; } - -protected: - indexed_accessor_iterator(BaseT base, ptrdiff_t index) - : base(base), index(index) {} - BaseT base; - ptrdiff_t index; -}; - -namespace detail { -/// The class represents the base of a range of indexed_accessor_iterators. It -/// provides support for many different range functionalities, e.g. -/// drop_front/slice/etc.. Derived range classes must implement the following -/// static methods: -/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index) -/// - Dereference an iterator pointing to the base object at the given -/// index. -/// * BaseT offset_base(const BaseT &base, ptrdiff_t index) -/// - Return a new base that is offset from the provide base by 'index' -/// elements. -template -class indexed_accessor_range_base { -public: - using RangeBaseT = - indexed_accessor_range_base; - - /// An iterator element of this range. - class iterator : public indexed_accessor_iterator { - public: - // Index into this iterator, invoking a static method on the derived type. - ReferenceT operator*() const { - return DerivedT::dereference_iterator(this->getBase(), this->getIndex()); - } - - private: - iterator(BaseT owner, ptrdiff_t curIndex) - : indexed_accessor_iterator( - owner, curIndex) {} - - /// Allow access to the constructor. - friend indexed_accessor_range_base; - }; - - indexed_accessor_range_base(iterator begin, iterator end) - : base(DerivedT::offset_base(begin.getBase(), begin.getIndex())), - count(end.getIndex() - begin.getIndex()) {} - indexed_accessor_range_base(const iterator_range &range) - : indexed_accessor_range_base(range.begin(), range.end()) {} - indexed_accessor_range_base(BaseT base, ptrdiff_t count) - : base(base), count(count) {} - - iterator begin() const { return iterator(base, 0); } - iterator end() const { return iterator(base, count); } - ReferenceT operator[](unsigned index) const { - assert(index < size() && "invalid index for value range"); - return DerivedT::dereference_iterator(base, index); - } - - /// Compare this range with another. - template bool operator==(const OtherT &other) { - return size() == llvm::size(other) && - std::equal(begin(), end(), other.begin()); - } - - /// Return the size of this range. - size_t size() const { return count; } - - /// Return if the range is empty. - bool empty() const { return size() == 0; } - - /// Drop the first N elements, and keep M elements. - DerivedT slice(size_t n, size_t m) const { - assert(n + m <= size() && "invalid size specifiers"); - return DerivedT(DerivedT::offset_base(base, n), m); - } - - /// Drop the first n elements. - DerivedT drop_front(size_t n = 1) const { - assert(size() >= n && "Dropping more elements than exist"); - return slice(n, size() - n); - } - /// Drop the last n elements. - DerivedT drop_back(size_t n = 1) const { - assert(size() >= n && "Dropping more elements than exist"); - return DerivedT(base, size() - n); - } - - /// Take the first n elements. - DerivedT take_front(size_t n = 1) const { - return n < size() ? drop_back(size() - n) - : static_cast(*this); - } - - /// Take the last n elements. - DerivedT take_back(size_t n = 1) const { - return n < size() ? drop_front(size() - n) - : static_cast(*this); - } - - /// Allow conversion to SmallVector if necessary. - /// TODO(riverriddle) Remove this when SmallVector accepts different range - /// types in its constructor. - template operator SmallVector() const { - return {begin(), end()}; - } - -protected: - indexed_accessor_range_base(const indexed_accessor_range_base &) = default; - indexed_accessor_range_base(indexed_accessor_range_base &&) = default; - indexed_accessor_range_base & - operator=(const indexed_accessor_range_base &) = default; - - /// The base that owns the provided range of values. - BaseT base; - /// The size from the owning range. - ptrdiff_t count; -}; -} // end namespace detail - -/// This class provides an implementation of a range of -/// indexed_accessor_iterators where the base is not indexable. Ranges with -/// bases that are offsetable should derive from indexed_accessor_range_base -/// instead. Derived range classes are expected to implement the following -/// static method: -/// * ReferenceT dereference(const BaseT &base, ptrdiff_t index) -/// - Dereference an iterator pointing to a parent base at the given index. -template -class indexed_accessor_range - : public detail::indexed_accessor_range_base< - DerivedT, std::pair, T, PointerT, ReferenceT> { -public: - indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count) - : detail::indexed_accessor_range_base< - DerivedT, std::pair, T, PointerT, ReferenceT>( - std::make_pair(base, startIndex), count) {} - using detail::indexed_accessor_range_base< - DerivedT, std::pair, T, PointerT, - ReferenceT>::indexed_accessor_range_base; - - /// Returns the current base of the range. - const BaseT &getBase() const { return this->base.first; } - - /// Returns the current start index of the range. - ptrdiff_t getStartIndex() const { return this->base.second; } - - /// See `detail::indexed_accessor_range_base` for details. - static std::pair - offset_base(const std::pair &base, ptrdiff_t index) { - // We encode the internal base as a pair of the derived base and a start - // index into the derived base. - return std::make_pair(base.first, base.second + index); - } - /// See `detail::indexed_accessor_range_base` for details. - static ReferenceT - dereference_iterator(const std::pair &base, - ptrdiff_t index) { - return DerivedT::dereference(base.first, base.second + index); - } -}; - -/// Given a container of pairs, return a range over the second elements. -template auto make_second_range(ContainerTy &&c) { - return llvm::map_range( - std::forward(c), - [](decltype((*std::begin(c))) elt) -> decltype((elt.second)) { - return elt.second; - }); -} - -/// A range class that repeats a specific value for a set number of times. -template -class RepeatRange - : public detail::indexed_accessor_range_base, T, const T> { -public: - using detail::indexed_accessor_range_base< - RepeatRange, T, const T>::indexed_accessor_range_base; - - /// Given that we are repeating a specific value, we can simply return that - /// value when offsetting the base or dereferencing the iterator. - static T offset_base(const T &val, ptrdiff_t) { return val; } - static const T &dereference_iterator(const T &val, ptrdiff_t) { return val; } -}; - -/// Make a range that repeats the given value 'n' times. -template -RepeatRange make_repeated_range(const ValueTy &value, size_t n) { - return RepeatRange(value, n); -} - -/// Returns true of the given range only contains a single element. -template bool has_single_element(ContainerTy &&c) { - auto it = std::begin(c), e = std::end(c); - return it != e && std::next(it) == e; -} - } // end namespace mlir #endif // MLIR_SUPPORT_STLEXTRAS_H diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 0d03dd70039da..8fb03fbf77384 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1407,7 +1407,7 @@ struct AffineForEmptyLoopFolder : public OpRewritePattern { LogicalResult matchAndRewrite(AffineForOp forOp, PatternRewriter &rewriter) const override { // Check that the body only contains a terminator. - if (!has_single_element(*forOp.getBody())) + if (!llvm::hasSingleElement(*forOp.getBody())) return failure(); rewriter.eraseOp(forOp); return success(); @@ -1576,7 +1576,8 @@ struct SimplifyDeadElse : public OpRewritePattern { LogicalResult matchAndRewrite(AffineIfOp ifOp, PatternRewriter &rewriter) const override { - if (ifOp.elseRegion().empty() || !has_single_element(*ifOp.getElseBlock())) + if (ifOp.elseRegion().empty() || + !llvm::hasSingleElement(*ifOp.getElseBlock())) return failure(); rewriter.startRootUpdate(ifOp); diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp index 24dd018e0c447..b5bf47ccc21f3 100644 --- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp @@ -138,7 +138,7 @@ static void insertCopies(Region ®ion, Location loc, Value from, Value to) { (void)toType; assert(fromType.getShape() == toType.getShape()); assert(fromType.getRank() != 0); - assert(has_single_element(region) && + assert(llvm::hasSingleElement(region) && "unstructured control flow not supported"); OpBuilder builder(region.getContext()); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index c572be4d132e3..eacdf23b44fd0 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -31,7 +31,7 @@ using namespace mlir::loop; Optional RegionMatcher::matchAsScalarBinaryOp(GenericOp op) { auto ®ion = op.region(); - if (!has_single_element(region)) + if (!llvm::hasSingleElement(region)) return llvm::None; Block &block = region.front(); @@ -41,7 +41,7 @@ RegionMatcher::matchAsScalarBinaryOp(GenericOp op) { return llvm::None; auto &ops = block.getOperations(); - if (!has_single_element(block.without_terminator())) + if (!llvm::hasSingleElement(block.without_terminator())) return llvm::None; using mlir::matchers::m_Val; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 9a1bb68dd1a6f..1d1d4a6a3f975 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1956,7 +1956,7 @@ static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) { /// given `dstBlock`. static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) { // Check that there is only one op in the `srcBlock`. - if (!has_single_element(srcBlock)) + if (!llvm::hasSingleElement(srcBlock)) return false; auto branchOp = dyn_cast(srcBlock.back()); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index bbe04cf6e1f8a..df80e3e15b1ef 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -499,7 +499,7 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern { // Check that the successor block has a single predecessor. Block *succ = op.getDest(); Block *opParent = op.getOperation()->getBlock(); - if (succ == opParent || !has_single_element(succ->getPredecessors())) + if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) return failure(); // Merge the successor into the current block and erase the branch. diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index daa5f9c58fc66..657df5752adee 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -554,8 +554,8 @@ static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { /// Constructs a new iterator. DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( DenseElementsAttr attr, size_t index) - : indexed_accessor_iterator( + : llvm::indexed_accessor_iterator( attr.getAsOpaquePointer(), index) {} /// Accesses the Attribute value at this iterator position. diff --git a/mlir/lib/IR/Module.cpp b/mlir/lib/IR/Module.cpp index b01b3782014e0..b441462821b68 100644 --- a/mlir/lib/IR/Module.cpp +++ b/mlir/lib/IR/Module.cpp @@ -73,7 +73,7 @@ LogicalResult ModuleOp::verify() { auto &bodyRegion = getOperation()->getRegion(0); // The body must contain a single basic block. - if (!has_single_element(bodyRegion)) + if (!llvm::hasSingleElement(bodyRegion)) return emitOpError("expected body region to have a single block"); // Check that the body has no block arguments. diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index ef949323d18d3..1b5f78ce1fe43 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -158,7 +158,7 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) { this->base = owner.ptr.get(); } -/// See `detail::indexed_accessor_range_base` for details. +/// See `llvm::detail::indexed_accessor_range_base` for details. TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) { if (auto *value = object.dyn_cast()) return {value + index}; @@ -166,7 +166,7 @@ TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) { return {operand + index}; return {object.dyn_cast() + index}; } -/// See `detail::indexed_accessor_range_base` for details. +/// See `llvm::detail::indexed_accessor_range_base` for details. Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) { if (auto *value = object.dyn_cast()) return (value + index)->getType(); @@ -198,7 +198,7 @@ ArrayRef ResultRange::getTypes() const { return getBase()->getResultTypes(); } -/// See `indexed_accessor_range` for details. +/// See `llvm::indexed_accessor_range` for details. OpResult ResultRange::dereference(Operation *op, ptrdiff_t index) { return op->getResult(index); } @@ -215,7 +215,7 @@ ValueRange::ValueRange(ResultRange values) {values.getBase(), static_cast(values.getStartIndex())}, values.size()) {} -/// See `detail::indexed_accessor_range_base` for details. +/// See `llvm::detail::indexed_accessor_range_base` for details. ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner, ptrdiff_t index) { if (auto *value = owner.ptr.dyn_cast()) @@ -225,7 +225,7 @@ ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner, Operation *operation = reinterpret_cast(owner.ptr.get()); return {operation, owner.startIndex + static_cast(index)}; } -/// See `detail::indexed_accessor_range_base` for details. +/// See `llvm::detail::indexed_accessor_range_base` for details. Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) { if (auto *value = owner.ptr.dyn_cast()) return value[index]; diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index ae4d15fabb289..f039f43a0ac98 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -214,14 +214,14 @@ RegionRange::RegionRange(MutableArrayRef regions) RegionRange::RegionRange(ArrayRef> regions) : RegionRange(regions.data(), regions.size()) {} -/// See `detail::indexed_accessor_range_base` for details. +/// See `llvm::detail::indexed_accessor_range_base` for details. RegionRange::OwnerT RegionRange::offset_base(const OwnerT &owner, ptrdiff_t index) { if (auto *operand = owner.dyn_cast *>()) return operand + index; return &owner.get()[index]; } -/// See `detail::indexed_accessor_range_base` for details. +/// See `llvm::detail::indexed_accessor_range_base` for details. Region *RegionRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) { if (auto *operand = owner.dyn_cast *>()) diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index f6fe0cef94f8f..2b1d99b0a363f 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -79,7 +79,7 @@ SymbolTable::SymbolTable(Operation *symbolTableOp) "expected operation to have SymbolTable trait"); assert(symbolTableOp->getNumRegions() == 1 && "expected operation to have a single region"); - assert(has_single_element(symbolTableOp->getRegion(0)) && + assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) && "expected operation to have a single block"); for (auto &op : symbolTableOp->getRegion(0).front()) { @@ -290,7 +290,7 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) { if (op->getNumRegions() != 1) return op->emitOpError() << "Operations with a 'SymbolTable' must have exactly one region"; - if (!has_single_element(op->getRegion(0))) + if (!llvm::hasSingleElement(op->getRegion(0))) return op->emitOpError() << "Operations with a 'SymbolTable' must have exactly one block"; diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 15f5c25a08a82..72f889e2315a0 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -414,7 +414,7 @@ LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp, return promoteIfSingleIteration(forOp); // Nothing in the loop body other than the terminator. - if (has_single_element(forOp.getBody()->getOperations())) + if (llvm::hasSingleElement(forOp.getBody()->getOperations())) return success(); // Loops where the lower bound is a max expression isn't supported for @@ -538,7 +538,7 @@ LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, return promoteIfSingleIteration(forOp); // Nothing in the loop body other than the terminator. - if (has_single_element(forOp.getBody()->getOperations())) + if (llvm::hasSingleElement(forOp.getBody()->getOperations())) return success(); // Loops where both lower and upper bounds are multi-result maps won't be diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index 28e79b4f75746..79a5297b66e01 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -10,5 +10,4 @@ add_subdirectory(Dialect) add_subdirectory(IR) add_subdirectory(Pass) add_subdirectory(SDBM) -add_subdirectory(Support) add_subdirectory(TableGen) diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt deleted file mode 100644 index d7796e63f26ae..0000000000000 --- a/mlir/unittests/Support/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -add_mlir_unittest(MLIRSupportTests - IndexedAccessorTest.cpp -) - -target_link_libraries(MLIRSupportTests - PRIVATE MLIRSupport)