Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] Add walk_with_filter to walk subset of IR #7591

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions integration_test/Bindings/Python/support/walk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# REQUIRES: bindings_python
# RUN: %PYTHON% %s | FileCheck %s

import circt
from circt.support import walk_with_filter
from circt.dialects import hw
from circt.ir import Context, Module, WalkOrder, WalkResult


def test_walk_with_filter():
ctx = Context()
circt.register_dialects(ctx)
module = Module.parse(
r"""
builtin.module {
hw.module @f() {
hw.output
}
}
""",
ctx,
)

def callback(op):
print(op.name)
return WalkResult.ADVANCE

# Test post-order walk.
# CHECK: Post-order
# CHECK-NEXT: hw.output
# CHECK-NEXT: hw.module
# CHECK-NOT: builtin.module
print("Post-order")
walk_with_filter(module.operation, [hw.HWModuleOp, hw.OutputOp], callback,
WalkOrder.POST_ORDER)

# Test pre-order walk.
# CHECK-NEXT: Pre-order
# CHECK-NOT: builtin.module
# CHECK-NEXT: hw.module
# CHECK-NEXT: hw.output
print("Pre-order")
walk_with_filter(module.operation, [hw.HWModuleOp, hw.OutputOp], callback,
WalkOrder.PRE_ORDER)

# Test interrupt.
# CHECK-NEXT: Interrupt post-order
# CHECK-NEXT: hw.output
print("Interrupt post-order")

def interrupt_callback(op):
print(op.name)
return WalkResult.INTERRUPT

walk_with_filter(module.operation, [hw.OutputOp], interrupt_callback,
WalkOrder.POST_ORDER)

# Test exception.
# CHECK: Exception
# CHECK-NEXT: hw.output
# CHECK-NEXT: Exception raised
print("Exception")

def exception_callback(op):
print(op.name)
raise ValueError
return WalkResult.ADVANCE

try:
walk_with_filter(module.operation, [hw.OutputOp], exception_callback,
WalkOrder.POST_ORDER)
except RuntimeError:
print("Exception raised")


test_walk_with_filter()
4 changes: 3 additions & 1 deletion lib/Bindings/Python/CIRCTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"

#include "circt-c/Conversion.h"
#include "circt-c/Dialect/Comb.h"
Expand Down Expand Up @@ -142,4 +142,6 @@ PYBIND11_MODULE(_circt, m) {
circt::python::populateDialectOMSubmodule(om);
py::module sv = m.def_submodule("_sv", "SV API");
circt::python::populateDialectSVSubmodule(sv);
py::module support = m.def_submodule("_support", "CIRCT support");
circt::python::populateSupportSubmodule(support);
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
//===- DialectModules.h - Populate submodules -----------------------------===//
//===- CIRCTModules.h - Populate submodules -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Functions to populate each dialect's submodule (if provided).
// Functions to populate submodules in CIRCT (if provided).
//
//===----------------------------------------------------------------------===//

#ifndef CIRCT_BINDINGS_PYTHON_DIALECTMODULES_H
#define CIRCT_BINDINGS_PYTHON_DIALECTMODULES_H
#ifndef CIRCT_BINDINGS_PYTHON_CIRCTMODULES_H
#define CIRCT_BINDINGS_PYTHON_CIRCTMODULES_H

#include <pybind11/pybind11.h>

Expand All @@ -24,6 +24,7 @@ void populateDialectMSFTSubmodule(pybind11::module &m);
void populateDialectOMSubmodule(pybind11::module &m);
void populateDialectSeqSubmodule(pybind11::module &m);
void populateDialectSVSubmodule(pybind11::module &m);
void populateSupportSubmodule(pybind11::module &m);

} // namespace python
} // namespace circt
Expand Down
1 change: 1 addition & 0 deletions lib/Bindings/Python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ declare_mlir_python_extension(CIRCTBindingsPythonExtension
OMModule.cpp
MSFTModule.cpp
SeqModule.cpp
SupportModule.cpp
SVModule.cpp
EMBED_CAPI_LINK_LIBS
CIRCTCAPIComb
Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/ESIModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"

#include "circt/Dialect/ESI/ESIDialect.h"

Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/HWModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"

#include "circt-c/Dialect/HW.h"

Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/MSFTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"

#include "circt-c/Dialect/MSFT.h"
#include "circt/Dialect/MSFT/MSFTDialect.h"
Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/OMModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"
#include "circt-c/Dialect/HW.h"
#include "circt-c/Dialect/OM.h"
#include "mlir-c/BuiltinAttributes.h"
Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/SVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"

#include "circt-c/Dialect/SV.h"
#include "mlir-c/Bindings/Python/Interop.h"
Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/SeqModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"

#include "circt-c/Dialect/Seq.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
Expand Down
86 changes: 86 additions & 0 deletions lib/Bindings/Python/SupportModule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//===- SupportModule.cpp - Support API pybind module ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CIRCTModules.h"

#include "mlir/Bindings/Python/PybindAdaptors.h"

#include "PybindUtils.h"
#include "mlir-c/Support.h"
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>

namespace py = pybind11;

using namespace circt;
using namespace mlir::python::adaptors;

/// Populate the support python module.
void circt::python::populateSupportSubmodule(py::module &m) {
m.doc() = "CIRCT Python utils";
// Walk with filter.
m.def(
"_walk_with_filter",
[](MlirOperation operation, std::vector<std::string> op_names,
std::function<MlirWalkResult(MlirOperation)> callback,
MlirWalkOrder walkOrder) {
struct UserData {
std::function<MlirWalkResult(MlirOperation)> callback;
bool gotException;
std::string exceptionWhat;
py::object exceptionType;
std::vector<MlirIdentifier> op_names;
};

std::vector<MlirIdentifier> op_names_identifiers;

// Construct MlirIdentifier from string to perform pointer comparison.
for (auto &op_name : op_names)
op_names_identifiers.push_back(mlirIdentifierGet(
mlirOperationGetContext(operation),
mlirStringRefCreateFromCString(op_name.c_str())));

UserData userData{callback, false, {}, {}, op_names_identifiers};
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
void *userData) {
UserData *calleeUserData = static_cast<UserData *>(userData);
auto op_name = mlirOperationGetName(op);

// Check if the operation name is in the filter.
bool inFilter = false;
for (auto &op_name_identifier : calleeUserData->op_names) {
if (mlirIdentifierEqual(op_name, op_name_identifier)) {
inFilter = true;
break;
}
}

// If the operation name is not in the filter, skip it.
if (!inFilter)
return MlirWalkResult::MlirWalkResultAdvance;

try {
return (calleeUserData->callback)(op);
} catch (py::error_already_set &e) {
calleeUserData->gotException = true;
calleeUserData->exceptionWhat = e.what();
calleeUserData->exceptionType = e.type();
return MlirWalkResult::MlirWalkResultInterrupt;
}
};
mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
if (userData.gotException) {
std::string message("Exception raised in callback: ");
message.append(userData.exceptionWhat);
throw std::runtime_error(message);
}
},
py::arg("op"), py::arg("op_names"), py::arg("callback"),
py::arg("walk_order"));
}
12 changes: 12 additions & 0 deletions lib/Bindings/Python/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from . import ir

from ._mlir_libs._circt._support import _walk_with_filter
from .ir import Operation
from contextlib import AbstractContextManager
from contextvars import ContextVar
from typing import List
Expand Down Expand Up @@ -409,3 +411,13 @@ def create_default_value(self, index, data_type, arg_name):
def operation(self):
"""Get the operation associated with this builder."""
return self.opview.operation


# Helper function to walk operation with a filter on operation names.
# `op_views` is a list of operation views to visit. This is a wrapper
# around the C++ implementation of walk_with_filter.
def walk_with_filter(operation: Operation, op_views: List[ir.OpView], callback,
walk_order):
op_names_identifiers = [name.OPERATION_NAME for name in op_views]
return _walk_with_filter(operation, op_names_identifiers, callback,
walk_order)
Loading