-
Notifications
You must be signed in to change notification settings - Fork 307
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[python] Add
walk_with_filter
to walk subset of IR
This adds `walk_with_filter` python method to invoke callbacks only for subset of operations. We require GIL to call python function so this could improve performance of the tools based on Python API. ``` Starting walk_with_filter walk_with_filter elapsed time: 0.005462 seconds cnt=1 Starting operation.walk operation.walk elapsed time: 1.061360 seconds cnt=2 ```
- Loading branch information
Showing
12 changed files
with
189 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# REQUIRES: bindings_python | ||
# RUN: %PYTHON% %s | FileCheck %s | ||
|
||
import circt | ||
from circt.support import walk_with_filter | ||
from circt.dialects import hw | ||
from circt.ir import Context, Module, WalkOrder, WalkResult | ||
|
||
|
||
def test_walk_with_filter(): | ||
ctx = Context() | ||
circt.register_dialects(ctx) | ||
module = Module.parse( | ||
r""" | ||
builtin.module { | ||
hw.module @f() { | ||
hw.output | ||
} | ||
} | ||
""", | ||
ctx, | ||
) | ||
|
||
def callback(op): | ||
print(op.name) | ||
return WalkResult.ADVANCE | ||
|
||
# Test post-order walk. | ||
# CHECK: Post-order | ||
# CHECK-NEXT: hw.output | ||
# CHECK-NEXT: hw.module | ||
# CHECK-NOT: builtin.module | ||
print("Post-order") | ||
walk_with_filter(module.operation, [hw.HWModuleOp, hw.OutputOp], callback, | ||
WalkOrder.POST_ORDER) | ||
|
||
# Test pre-order walk. | ||
# CHECK-NEXT: Pre-order | ||
# CHECK-NOT: builtin.module | ||
# CHECK-NEXT: hw.module | ||
# CHECK-NEXT: hw.output | ||
print("Pre-order") | ||
walk_with_filter(module.operation, [hw.HWModuleOp, hw.OutputOp], callback, | ||
WalkOrder.PRE_ORDER) | ||
|
||
# Test interrupt. | ||
# CHECK-NEXT: Interrupt post-order | ||
# CHECK-NEXT: hw.output | ||
print("Interrupt post-order") | ||
|
||
def interrupt_callback(op): | ||
print(op.name) | ||
return WalkResult.INTERRUPT | ||
|
||
walk_with_filter(module.operation, [hw.OutputOp], interrupt_callback, | ||
WalkOrder.POST_ORDER) | ||
|
||
# Test exception. | ||
# CHECK: Exception | ||
# CHECK-NEXT: hw.output | ||
# CHECK-NEXT: Exception raised | ||
print("Exception") | ||
|
||
def exception_callback(op): | ||
print(op.name) | ||
raise ValueError | ||
return WalkResult.ADVANCE | ||
|
||
try: | ||
walk_with_filter(module.operation, [hw.OutputOp], exception_callback, | ||
WalkOrder.POST_ORDER) | ||
except RuntimeError: | ||
print("Exception raised") | ||
|
||
|
||
test_walk_with_filter() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
//===- SupportModule.cpp - Support API pybind module ----------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "CIRCTModules.h" | ||
|
||
#include "mlir/Bindings/Python/PybindAdaptors.h" | ||
|
||
#include "PybindUtils.h" | ||
#include "mlir-c/Support.h" | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/pytypes.h> | ||
#include <pybind11/stl.h> | ||
|
||
namespace py = pybind11; | ||
|
||
using namespace circt; | ||
using namespace mlir::python::adaptors; | ||
|
||
/// Populate the support python module. | ||
void circt::python::populateSupportSubmodule(py::module &m) { | ||
m.doc() = "CIRCT Python utils"; | ||
// Walk with filter. | ||
m.def( | ||
"_walk_with_filter", | ||
[](MlirOperation operation, std::vector<std::string> op_names, | ||
std::function<MlirWalkResult(MlirOperation)> callback, | ||
MlirWalkOrder walkOrder) { | ||
struct UserData { | ||
std::function<MlirWalkResult(MlirOperation)> callback; | ||
bool gotException; | ||
std::string exceptionWhat; | ||
py::object exceptionType; | ||
std::vector<MlirIdentifier> op_names; | ||
}; | ||
|
||
std::vector<MlirIdentifier> op_names_identifiers; | ||
|
||
// Construct MlirIdentifier from string to perform pointer comparison. | ||
for (auto &op_name : op_names) | ||
op_names_identifiers.push_back(mlirIdentifierGet( | ||
mlirOperationGetContext(operation), | ||
mlirStringRefCreateFromCString(op_name.c_str()))); | ||
|
||
UserData userData{callback, false, {}, {}, op_names_identifiers}; | ||
MlirOperationWalkCallback walkCallback = [](MlirOperation op, | ||
void *userData) { | ||
UserData *calleeUserData = static_cast<UserData *>(userData); | ||
auto op_name = mlirOperationGetName(op); | ||
|
||
// Check if the operation name is in the filter. | ||
bool inFilter = false; | ||
for (auto &op_name_identifier : calleeUserData->op_names) { | ||
if (mlirIdentifierEqual(op_name, op_name_identifier)) { | ||
inFilter = true; | ||
break; | ||
} | ||
} | ||
|
||
// If the operation name is not in the filter, skip it. | ||
if (!inFilter) | ||
return MlirWalkResult::MlirWalkResultAdvance; | ||
|
||
try { | ||
return (calleeUserData->callback)(op); | ||
} catch (py::error_already_set &e) { | ||
calleeUserData->gotException = true; | ||
calleeUserData->exceptionWhat = e.what(); | ||
calleeUserData->exceptionType = e.type(); | ||
return MlirWalkResult::MlirWalkResultInterrupt; | ||
} | ||
}; | ||
mlirOperationWalk(operation, walkCallback, &userData, walkOrder); | ||
if (userData.gotException) { | ||
std::string message("Exception raised in callback: "); | ||
message.append(userData.exceptionWhat); | ||
throw std::runtime_error(message); | ||
} | ||
}, | ||
py::arg("op"), py::arg("op_names"), py::arg("callback"), | ||
py::arg("walk_order")); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters