Skip to content

Commit

Permalink
[mlir][python] Fix build on windows
Browse files Browse the repository at this point in the history
Reviewed By: stella.stamenova, ashay-github

Differential Revision: https://reviews.llvm.org/D131906
  • Loading branch information
Mogball committed Aug 15, 2022
1 parent c228410 commit 133624a
Showing 1 changed file with 14 additions and 24 deletions.
38 changes: 14 additions & 24 deletions mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,9 @@ static T pyTryCast(py::handle object) {
/// A python-wrapped dense array attribute with an element type and a derived
/// implementation class.
template <typename EltTy, typename DerivedT>
class PyDenseArrayAttribute
: public PyConcreteAttribute<PyDenseArrayAttribute<EltTy, DerivedT>> {
class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
public:
static constexpr typename PyConcreteAttribute<
PyDenseArrayAttribute<EltTy, DerivedT>>::IsAFunctionTy isaFunction =
DerivedT::isaFunction;
static constexpr const char *pyClassName = DerivedT::pyClassName;
using PyConcreteAttribute<
PyDenseArrayAttribute<EltTy, DerivedT>>::PyConcreteAttribute;
using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;

/// Iterator over the integer elements of a dense array.
class PyDenseArrayIterator {
Expand Down Expand Up @@ -158,33 +152,29 @@ class PyDenseArrayAttribute
EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }

/// Bind the attribute class.
static void bindDerived(typename PyConcreteAttribute<
PyDenseArrayAttribute<EltTy, DerivedT>>::ClassTy &c) {
static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
// Bind the constructor.
c.def_static(
"get",
[](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
MlirAttribute attr =
DerivedT::getAttribute(ctx->get(), values.size(), values.data());
return PyDenseArrayAttribute<EltTy, DerivedT>(ctx->getRef(), attr);
return DerivedT(ctx->getRef(), attr);
},
py::arg("values"), py::arg("context") = py::none(),
"Gets a uniqued dense array attribute");
// Bind the array methods.
c.def("__getitem__",
[](PyDenseArrayAttribute<EltTy, DerivedT> &arr, intptr_t i) {
if (i >= mlirDenseArrayGetNumElements(arr))
throw py::index_error("DenseArray index out of range");
return arr.getItem(i);
});
c.def("__len__", [](const PyDenseArrayAttribute<EltTy, DerivedT> &arr) {
return mlirDenseArrayGetNumElements(arr);
c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
if (i >= mlirDenseArrayGetNumElements(arr))
throw py::index_error("DenseArray index out of range");
return arr.getItem(i);
});
c.def("__iter__", [](const PyDenseArrayAttribute<EltTy, DerivedT> &arr) {
return PyDenseArrayIterator(arr);
c.def("__len__", [](const DerivedT &arr) {
return mlirDenseArrayGetNumElements(arr);
});
c.def("__add__", [](PyDenseArrayAttribute<EltTy, DerivedT> &arr,
py::list extras) {
c.def("__iter__",
[](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
c.def("__add__", [](DerivedT &arr, py::list extras) {
std::vector<EltTy> values;
intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
values.reserve(numOldElements + py::len(extras));
Expand All @@ -194,7 +184,7 @@ class PyDenseArrayAttribute
values.push_back(pyTryCast<EltTy>(attr));
MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(),
values.size(), values.data());
return PyDenseArrayAttribute<EltTy, DerivedT>(arr.getContext(), attr);
return DerivedT(arr.getContext(), attr);
});
}
};
Expand Down

0 comments on commit 133624a

Please sign in to comment.