Skip to content

Commit

Permalink
[llvm][STLExtras] Move various iterator/range utilities from MLIR to …
Browse files Browse the repository at this point in the history
…LLVM

This revision moves the various range utilities present in MLIR to LLVM to enable greater reuse. This revision moves the following utilities:

* indexed_accessor_*
This is set of utility iterator/range base classes that allow for building a range class where the iterators are represented by an object+index pair.

* make_second_range
Given a range of pairs, returns a range iterating over the `second` elements.

* hasSingleElement
Returns if the given range has 1 element. size() == 1 checks end up being very common, but size() is not always O(1) (e.g., ilist). This method provides O(1) checks for those cases.

Differential Revision: https://reviews.llvm.org/D78064
  • Loading branch information
River707 committed Apr 14, 2020
1 parent 8cbe371 commit 204c3b5
Show file tree
Hide file tree
Showing 24 changed files with 275 additions and 303 deletions.
213 changes: 213 additions & 0 deletions llvm/include/llvm/ADT/STLExtras.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ constexpr bool empty(const T &RangeOrContainer) {
return adl_begin(RangeOrContainer) == adl_end(RangeOrContainer);
}

/// Returns true of the given range only contains a single element.
template <typename ContainerTy> bool hasSingleElement(ContainerTy &&c) {
auto it = std::begin(c), e = std::end(c);
return it != e && std::next(it) == e;
}

/// Return a range covering \p RangeOrContainer with the first N elements
/// excluded.
template <typename T> auto drop_begin(T &&RangeOrContainer, size_t N) {
Expand Down Expand Up @@ -1017,6 +1023,213 @@ detail::concat_range<ValueT, RangeTs...> concat(RangeTs &&... Ranges) {
std::forward<RangeTs>(Ranges)...);
}

/// A utility class used to implement an iterator that contains some base object
/// and an index. The iterator moves the index but keeps the base constant.
template <typename DerivedT, typename BaseT, typename T,
typename PointerT = T *, typename ReferenceT = T &>
class indexed_accessor_iterator
: public llvm::iterator_facade_base<DerivedT,
std::random_access_iterator_tag, T,
std::ptrdiff_t, PointerT, ReferenceT> {
public:
ptrdiff_t operator-(const indexed_accessor_iterator &rhs) const {
assert(base == rhs.base && "incompatible iterators");
return index - rhs.index;
}
bool operator==(const indexed_accessor_iterator &rhs) const {
return base == rhs.base && index == rhs.index;
}
bool operator<(const indexed_accessor_iterator &rhs) const {
assert(base == rhs.base && "incompatible iterators");
return index < rhs.index;
}

DerivedT &operator+=(ptrdiff_t offset) {
this->index += offset;
return static_cast<DerivedT &>(*this);
}
DerivedT &operator-=(ptrdiff_t offset) {
this->index -= offset;
return static_cast<DerivedT &>(*this);
}

/// Returns the current index of the iterator.
ptrdiff_t getIndex() const { return index; }

/// Returns the current base of the iterator.
const BaseT &getBase() const { return base; }

protected:
indexed_accessor_iterator(BaseT base, ptrdiff_t index)
: base(base), index(index) {}
BaseT base;
ptrdiff_t index;
};

namespace detail {
/// The class represents the base of a range of indexed_accessor_iterators. It
/// provides support for many different range functionalities, e.g.
/// drop_front/slice/etc.. Derived range classes must implement the following
/// static methods:
/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index)
/// - Dereference an iterator pointing to the base object at the given
/// index.
/// * BaseT offset_base(const BaseT &base, ptrdiff_t index)
/// - Return a new base that is offset from the provide base by 'index'
/// elements.
template <typename DerivedT, typename BaseT, typename T,
typename PointerT = T *, typename ReferenceT = T &>
class indexed_accessor_range_base {
public:
using RangeBaseT =
indexed_accessor_range_base<DerivedT, BaseT, T, PointerT, ReferenceT>;

/// An iterator element of this range.
class iterator : public indexed_accessor_iterator<iterator, BaseT, T,
PointerT, ReferenceT> {
public:
// Index into this iterator, invoking a static method on the derived type.
ReferenceT operator*() const {
return DerivedT::dereference_iterator(this->getBase(), this->getIndex());
}

private:
iterator(BaseT owner, ptrdiff_t curIndex)
: indexed_accessor_iterator<iterator, BaseT, T, PointerT, ReferenceT>(
owner, curIndex) {}

/// Allow access to the constructor.
friend indexed_accessor_range_base<DerivedT, BaseT, T, PointerT,
ReferenceT>;
};

indexed_accessor_range_base(iterator begin, iterator end)
: base(DerivedT::offset_base(begin.getBase(), begin.getIndex())),
count(end.getIndex() - begin.getIndex()) {}
indexed_accessor_range_base(const iterator_range<iterator> &range)
: indexed_accessor_range_base(range.begin(), range.end()) {}
indexed_accessor_range_base(BaseT base, ptrdiff_t count)
: base(base), count(count) {}

iterator begin() const { return iterator(base, 0); }
iterator end() const { return iterator(base, count); }
ReferenceT operator[](unsigned index) const {
assert(index < size() && "invalid index for value range");
return DerivedT::dereference_iterator(base, index);
}

/// Compare this range with another.
template <typename OtherT> bool operator==(const OtherT &other) {
return size() == std::distance(other.begin(), other.end()) &&
std::equal(begin(), end(), other.begin());
}

/// Return the size of this range.
size_t size() const { return count; }

/// Return if the range is empty.
bool empty() const { return size() == 0; }

/// Drop the first N elements, and keep M elements.
DerivedT slice(size_t n, size_t m) const {
assert(n + m <= size() && "invalid size specifiers");
return DerivedT(DerivedT::offset_base(base, n), m);
}

/// Drop the first n elements.
DerivedT drop_front(size_t n = 1) const {
assert(size() >= n && "Dropping more elements than exist");
return slice(n, size() - n);
}
/// Drop the last n elements.
DerivedT drop_back(size_t n = 1) const {
assert(size() >= n && "Dropping more elements than exist");
return DerivedT(base, size() - n);
}

/// Take the first n elements.
DerivedT take_front(size_t n = 1) const {
return n < size() ? drop_back(size() - n)
: static_cast<const DerivedT &>(*this);
}

/// Take the last n elements.
DerivedT take_back(size_t n = 1) const {
return n < size() ? drop_front(size() - n)
: static_cast<const DerivedT &>(*this);
}

/// Allow conversion to any type accepting an iterator_range.
template <typename RangeT, typename = std::enable_if_t<std::is_constructible<
RangeT, iterator_range<iterator>>::value>>
operator RangeT() const {
return RangeT(iterator_range<iterator>(*this));
}

protected:
indexed_accessor_range_base(const indexed_accessor_range_base &) = default;
indexed_accessor_range_base(indexed_accessor_range_base &&) = default;
indexed_accessor_range_base &
operator=(const indexed_accessor_range_base &) = default;

/// The base that owns the provided range of values.
BaseT base;
/// The size from the owning range.
ptrdiff_t count;
};
} // end namespace detail

/// This class provides an implementation of a range of
/// indexed_accessor_iterators where the base is not indexable. Ranges with
/// bases that are offsetable should derive from indexed_accessor_range_base
/// instead. Derived range classes are expected to implement the following
/// static method:
/// * ReferenceT dereference(const BaseT &base, ptrdiff_t index)
/// - Dereference an iterator pointing to a parent base at the given index.
template <typename DerivedT, typename BaseT, typename T,
typename PointerT = T *, typename ReferenceT = T &>
class indexed_accessor_range
: public detail::indexed_accessor_range_base<
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT> {
public:
indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count)
: detail::indexed_accessor_range_base<
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>(
std::make_pair(base, startIndex), count) {}
using detail::indexed_accessor_range_base<
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT,
ReferenceT>::indexed_accessor_range_base;

/// Returns the current base of the range.
const BaseT &getBase() const { return this->base.first; }

/// Returns the current start index of the range.
ptrdiff_t getStartIndex() const { return this->base.second; }

/// See `detail::indexed_accessor_range_base` for details.
static std::pair<BaseT, ptrdiff_t>
offset_base(const std::pair<BaseT, ptrdiff_t> &base, ptrdiff_t index) {
// We encode the internal base as a pair of the derived base and a start
// index into the derived base.
return std::make_pair(base.first, base.second + index);
}
/// See `detail::indexed_accessor_range_base` for details.
static ReferenceT
dereference_iterator(const std::pair<BaseT, ptrdiff_t> &base,
ptrdiff_t index) {
return DerivedT::dereference(base.first, base.second + index);
}
};

/// Given a container of pairs, return a range over the second elements.
template <typename ContainerTy> auto make_second_range(ContainerTy &&c) {
return llvm::map_range(
std::forward<ContainerTy>(c),
[](decltype((*std::begin(c))) elt) -> decltype((elt.second)) {
return elt.second;
});
}

//===----------------------------------------------------------------------===//
// Extra additions to <utility>
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions llvm/unittests/Support/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ add_llvm_unittest(SupportTests
FormatVariadicTest.cpp
GlobPatternTest.cpp
Host.cpp
IndexedAccessorTest.cpp
ItaniumManglingCanonicalizerTest.cpp
JSONTest.cpp
KnownBitsTest.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "gmock/gmock.h"

using namespace mlir;
using namespace mlir::detail;
using namespace llvm;
using namespace llvm::detail;

namespace {
/// Simple indexed accessor range that wraps an array.
Expand All @@ -24,7 +24,7 @@ struct ArrayIndexedAccessorRange
using indexed_accessor_range<ArrayIndexedAccessorRange<T>, T *,
T>::indexed_accessor_range;

/// See `indexed_accessor_range` for details.
/// See `llvm::indexed_accessor_range` for details.
static T &dereference(T *data, ptrdiff_t index) { return data[index]; }
};
} // end anonymous namespace
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,16 @@ class StructType : public Type::TypeBase<StructType, CompositeType,

/// Range class for element types.
class ElementTypeRange
: public ::mlir::detail::indexed_accessor_range_base<
: public ::llvm::detail::indexed_accessor_range_base<
ElementTypeRange, const Type *, Type, Type, Type> {
private:
using RangeBaseT::RangeBaseT;

/// See `mlir::detail::indexed_accessor_range_base` for details.
/// See `llvm::detail::indexed_accessor_range_base` for details.
static const Type *offset_base(const Type *object, ptrdiff_t index) {
return object + index;
}
/// See `mlir::detail::indexed_accessor_range_base` for details.
/// See `llvm::detail::indexed_accessor_range_base` for details.
static Type dereference_iterator(const Type *object, ptrdiff_t index) {
return object[index];
}
Expand Down
14 changes: 8 additions & 6 deletions mlir/include/mlir/IR/Attributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -648,13 +648,14 @@ using DenseIterPtrAndSplat =
template <typename ConcreteT, typename T, typename PointerT = T *,
typename ReferenceT = T &>
class DenseElementIndexedIteratorImpl
: public indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
PointerT, ReferenceT> {
: public llvm::indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
PointerT, ReferenceT> {
protected:
DenseElementIndexedIteratorImpl(const char *data, bool isSplat,
size_t dataIndex)
: indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T, PointerT,
ReferenceT>({data, isSplat}, dataIndex) {}
: llvm::indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T,
PointerT, ReferenceT>({data, isSplat},
dataIndex) {}

/// Return the current index for this iterator, adjusted for the case of a
/// splat.
Expand Down Expand Up @@ -746,8 +747,9 @@ class DenseElementsAttr
/// A utility iterator that allows walking over the internal Attribute values
/// of a DenseElementsAttr.
class AttributeElementIterator
: public indexed_accessor_iterator<AttributeElementIterator, const void *,
Attribute, Attribute, Attribute> {
: public llvm::indexed_accessor_iterator<AttributeElementIterator,
const void *, Attribute,
Attribute, Attribute> {
public:
/// Accesses the Attribute value at this iterator position.
Attribute operator*() const;
Expand Down
8 changes: 4 additions & 4 deletions mlir/include/mlir/IR/BlockSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,19 @@ class PredecessorIterator final

/// This class implements the successor iterators for Block.
class SuccessorRange final
: public detail::indexed_accessor_range_base<SuccessorRange, BlockOperand *,
Block *, Block *, Block *> {
: public llvm::detail::indexed_accessor_range_base<
SuccessorRange, BlockOperand *, Block *, Block *, Block *> {
public:
using RangeBaseT::RangeBaseT;
SuccessorRange(Block *block);
SuccessorRange(Operation *term);

private:
/// See `detail::indexed_accessor_range_base` for details.
/// See `llvm::detail::indexed_accessor_range_base` for details.
static BlockOperand *offset_base(BlockOperand *object, ptrdiff_t index) {
return object + index;
}
/// See `detail::indexed_accessor_range_base` for details.
/// See `llvm::detail::indexed_accessor_range_base` for details.
static Block *dereference_iterator(BlockOperand *object, ptrdiff_t index) {
return object[index].get();
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class OpAsmPrinter {
void printArrowTypeList(TypeRange &&types) {
auto &os = getStream() << " -> ";

bool wrapped = !has_single_element(types) ||
bool wrapped = !llvm::hasSingleElement(types) ||
(*types.begin()).template isa<FunctionType>();
if (wrapped)
os << '(';
Expand Down
Loading

0 comments on commit 204c3b5

Please sign in to comment.