Skip to content

Commit

Permalink
[python] Add walk_with_filter to walk subset of IR (#7591)
Browse files Browse the repository at this point in the history
This adds `walk_with_filter` python method to invoke callbacks
only for subset of operations. We require GIL to call python function
so this could improve performance of the tools based on Python API.

```
Starting walk_with_filter
walk_with_filter elapsed time: 0.005462 seconds cnt=1
Starting operation.walk
operation.walk elapsed time: 1.061360 seconds cnt=2
```
  • Loading branch information
uenoku authored Sep 10, 2024
1 parent 6638aaf commit c433d61
Show file tree
Hide file tree
Showing 12 changed files with 189 additions and 11 deletions.
76 changes: 76 additions & 0 deletions integration_test/Bindings/Python/support/walk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# REQUIRES: bindings_python
# RUN: %PYTHON% %s | FileCheck %s

import circt
from circt.support import walk_with_filter
from circt.dialects import hw
from circt.ir import Context, Module, WalkOrder, WalkResult


def test_walk_with_filter():
ctx = Context()
circt.register_dialects(ctx)
module = Module.parse(
r"""
builtin.module {
hw.module @f() {
hw.output
}
}
""",
ctx,
)

def callback(op):
print(op.name)
return WalkResult.ADVANCE

# Test post-order walk.
# CHECK: Post-order
# CHECK-NEXT: hw.output
# CHECK-NEXT: hw.module
# CHECK-NOT: builtin.module
print("Post-order")
walk_with_filter(module.operation, [hw.HWModuleOp, hw.OutputOp], callback,
WalkOrder.POST_ORDER)

# Test pre-order walk.
# CHECK-NEXT: Pre-order
# CHECK-NOT: builtin.module
# CHECK-NEXT: hw.module
# CHECK-NEXT: hw.output
print("Pre-order")
walk_with_filter(module.operation, [hw.HWModuleOp, hw.OutputOp], callback,
WalkOrder.PRE_ORDER)

# Test interrupt.
# CHECK-NEXT: Interrupt post-order
# CHECK-NEXT: hw.output
print("Interrupt post-order")

def interrupt_callback(op):
print(op.name)
return WalkResult.INTERRUPT

walk_with_filter(module.operation, [hw.OutputOp], interrupt_callback,
WalkOrder.POST_ORDER)

# Test exception.
# CHECK: Exception
# CHECK-NEXT: hw.output
# CHECK-NEXT: Exception raised
print("Exception")

def exception_callback(op):
print(op.name)
raise ValueError
return WalkResult.ADVANCE

try:
walk_with_filter(module.operation, [hw.OutputOp], exception_callback,
WalkOrder.POST_ORDER)
except RuntimeError:
print("Exception raised")


test_walk_with_filter()
4 changes: 3 additions & 1 deletion lib/Bindings/Python/CIRCTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"

#include "circt-c/Conversion.h"
#include "circt-c/Dialect/Comb.h"
Expand Down Expand Up @@ -142,4 +142,6 @@ PYBIND11_MODULE(_circt, m) {
circt::python::populateDialectOMSubmodule(om);
py::module sv = m.def_submodule("_sv", "SV API");
circt::python::populateDialectSVSubmodule(sv);
py::module support = m.def_submodule("_support", "CIRCT support");
circt::python::populateSupportSubmodule(support);
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
//===- DialectModules.h - Populate submodules -----------------------------===//
//===- CIRCTModules.h - Populate submodules -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Functions to populate each dialect's submodule (if provided).
// Functions to populate submodules in CIRCT (if provided).
//
//===----------------------------------------------------------------------===//

#ifndef CIRCT_BINDINGS_PYTHON_DIALECTMODULES_H
#define CIRCT_BINDINGS_PYTHON_DIALECTMODULES_H
#ifndef CIRCT_BINDINGS_PYTHON_CIRCTMODULES_H
#define CIRCT_BINDINGS_PYTHON_CIRCTMODULES_H

#include <pybind11/pybind11.h>

Expand All @@ -24,6 +24,7 @@ void populateDialectMSFTSubmodule(pybind11::module &m);
void populateDialectOMSubmodule(pybind11::module &m);
void populateDialectSeqSubmodule(pybind11::module &m);
void populateDialectSVSubmodule(pybind11::module &m);
void populateSupportSubmodule(pybind11::module &m);

} // namespace python
} // namespace circt
Expand Down
1 change: 1 addition & 0 deletions lib/Bindings/Python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ declare_mlir_python_extension(CIRCTBindingsPythonExtension
OMModule.cpp
MSFTModule.cpp
SeqModule.cpp
SupportModule.cpp
SVModule.cpp
EMBED_CAPI_LINK_LIBS
CIRCTCAPIComb
Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/ESIModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"

#include "circt/Dialect/ESI/ESIDialect.h"

Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/HWModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"

#include "circt-c/Dialect/HW.h"

Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/MSFTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"

#include "circt-c/Dialect/MSFT.h"
#include "circt/Dialect/MSFT/MSFTDialect.h"
Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/OMModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"
#include "circt-c/Dialect/HW.h"
#include "circt-c/Dialect/OM.h"
#include "mlir-c/BuiltinAttributes.h"
Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/SVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"

#include "circt-c/Dialect/SV.h"
#include "mlir-c/Bindings/Python/Interop.h"
Expand Down
2 changes: 1 addition & 1 deletion lib/Bindings/Python/SeqModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "DialectModules.h"
#include "CIRCTModules.h"

#include "circt-c/Dialect/Seq.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
Expand Down
86 changes: 86 additions & 0 deletions lib/Bindings/Python/SupportModule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//===- SupportModule.cpp - Support API pybind module ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CIRCTModules.h"

#include "mlir/Bindings/Python/PybindAdaptors.h"

#include "PybindUtils.h"
#include "mlir-c/Support.h"
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>

namespace py = pybind11;

using namespace circt;
using namespace mlir::python::adaptors;

/// Populate the support python module.
void circt::python::populateSupportSubmodule(py::module &m) {
m.doc() = "CIRCT Python utils";
// Walk with filter.
m.def(
"_walk_with_filter",
[](MlirOperation operation, std::vector<std::string> op_names,
std::function<MlirWalkResult(MlirOperation)> callback,
MlirWalkOrder walkOrder) {
struct UserData {
std::function<MlirWalkResult(MlirOperation)> callback;
bool gotException;
std::string exceptionWhat;
py::object exceptionType;
std::vector<MlirIdentifier> op_names;
};

std::vector<MlirIdentifier> op_names_identifiers;

// Construct MlirIdentifier from string to perform pointer comparison.
for (auto &op_name : op_names)
op_names_identifiers.push_back(mlirIdentifierGet(
mlirOperationGetContext(operation),
mlirStringRefCreateFromCString(op_name.c_str())));

UserData userData{callback, false, {}, {}, op_names_identifiers};
MlirOperationWalkCallback walkCallback = [](MlirOperation op,
void *userData) {
UserData *calleeUserData = static_cast<UserData *>(userData);
auto op_name = mlirOperationGetName(op);

// Check if the operation name is in the filter.
bool inFilter = false;
for (auto &op_name_identifier : calleeUserData->op_names) {
if (mlirIdentifierEqual(op_name, op_name_identifier)) {
inFilter = true;
break;
}
}

// If the operation name is not in the filter, skip it.
if (!inFilter)
return MlirWalkResult::MlirWalkResultAdvance;

try {
return (calleeUserData->callback)(op);
} catch (py::error_already_set &e) {
calleeUserData->gotException = true;
calleeUserData->exceptionWhat = e.what();
calleeUserData->exceptionType = e.type();
return MlirWalkResult::MlirWalkResultInterrupt;
}
};
mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
if (userData.gotException) {
std::string message("Exception raised in callback: ");
message.append(userData.exceptionWhat);
throw std::runtime_error(message);
}
},
py::arg("op"), py::arg("op_names"), py::arg("callback"),
py::arg("walk_order"));
}
12 changes: 12 additions & 0 deletions lib/Bindings/Python/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from . import ir

from ._mlir_libs._circt._support import _walk_with_filter
from .ir import Operation
from contextlib import AbstractContextManager
from contextvars import ContextVar
from typing import List
Expand Down Expand Up @@ -409,3 +411,13 @@ def create_default_value(self, index, data_type, arg_name):
def operation(self):
"""Get the operation associated with this builder."""
return self.opview.operation


# Helper function to walk operation with a filter on operation names.
# `op_views` is a list of operation views to visit. This is a wrapper
# around the C++ implementation of walk_with_filter.
def walk_with_filter(operation: Operation, op_views: List[ir.OpView], callback,
walk_order):
op_names_identifiers = [name.OPERATION_NAME for name in op_views]
return _walk_with_filter(operation, op_names_identifiers, callback,
walk_order)

0 comments on commit c433d61

Please sign in to comment.