From c1a179c817578bc3f71e2e590930c713e7c6d2c9 Mon Sep 17 00:00:00 2001
From: Mashed Potato <38517644+potatomashed@users.noreply.github.com>
Date: Sat, 28 Dec 2024 10:56:14 -0800
Subject: [PATCH] AST Parser
---
CMakeLists.txt | 2 +-
README.md | 93 +-
include/mlc/core/typing.h | 38 +-
include/mlc/core/utils.h | 7 +
include/mlc/printer/ir_printer.h | 13 +-
pyproject.toml | 5 +-
python/mlc/__init__.py | 2 +-
python/mlc/_cython/core.pyx | 8 +-
python/mlc/ast/__init__.py | 4 -
python/mlc/ast/mlc_ast.py | 855 ------------------
python/mlc/ast/translate.py | 109 ---
python/mlc/core/typing.py | 34 +-
python/mlc/dataclasses/utils.py | 7 +-
python/mlc/parser/__init__.py | 3 +
python/mlc/parser/diagnostic.py | 116 +++
.../mlc/{ast/inspection.py => parser/env.py} | 182 ++--
python/mlc/parser/parser.py | 220 +++++
python/mlc/printer/ir_printer.py | 24 +-
python/mlc/testing/__init__.py | 0
python/mlc/testing/toy_ir/__init__.py | 3 +
python/mlc/testing/toy_ir/ir.py | 82 ++
python/mlc/testing/toy_ir/ir_builder.py | 69 ++
python/mlc/testing/toy_ir/parser.py | 66 ++
tests/cpp/test_base_optional.cc | 8 +-
tests/cpp/test_base_ref.cc | 4 +-
tests/python/test_parser_toy_ir_parser.py | 28 +
tests/python/test_printer_ir_printer.py | 73 +-
27 files changed, 793 insertions(+), 1262 deletions(-)
delete mode 100644 python/mlc/ast/__init__.py
delete mode 100644 python/mlc/ast/mlc_ast.py
delete mode 100644 python/mlc/ast/translate.py
create mode 100644 python/mlc/parser/__init__.py
create mode 100644 python/mlc/parser/diagnostic.py
rename python/mlc/{ast/inspection.py => parser/env.py} (60%)
create mode 100644 python/mlc/parser/parser.py
create mode 100644 python/mlc/testing/__init__.py
create mode 100644 python/mlc/testing/toy_ir/__init__.py
create mode 100644 python/mlc/testing/toy_ir/ir.py
create mode 100644 python/mlc/testing/toy_ir/ir_builder.py
create mode 100644 python/mlc/testing/toy_ir/parser.py
create mode 100644 tests/python/test_parser_toy_ir_parser.py
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8f40d945..961ccd48 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15)
project(
mlc
- VERSION 0.0.13
+ VERSION 0.0.14
DESCRIPTION "MLC-Python"
LANGUAGES C CXX
)
diff --git a/README.md b/README.md
index 2b93b258..a9fdd74f 100644
--- a/README.md
+++ b/README.md
@@ -121,87 +121,21 @@ ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent bindi
### :snake: Text Formats in Python
-**IR Printer.** By defining an `__ir_print__` method, which converts an IR node to MLC's Python-style AST, MLC's `IRPrinter` handles variable scoping, renaming and syntax highlighting automatically for a text format based on Python syntax.
+**Printer.** MLC converts an IR node to Python AST by looking up the `__ir_print__` method.
-Defining Python-based text format on a toy IR using `__ir_print__`.
+**[[Example](https://github.com/mlc-ai/mlc-python/blob/main/python/mlc/testing/toy_ir/ir.py)]**. Copy the toy IR definition into a REPL, then create a `Func` node as shown below:
```python
-import mlc.dataclasses as mlcd
-import mlc.printer as mlcp
-from mlc.printer import ast as mlt
-
-@mlcd.py_class
-class Expr(mlcd.PyClass): ...
-
-@mlcd.py_class
-class Stmt(mlcd.PyClass): ...
-
-@mlcd.py_class
-class Var(Expr):
- name: str
- def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
- if not printer.var_is_defined(obj=self):
- printer.var_def(obj=self, frame=printer.frames[-1], name=self.name)
- return printer.var_get(obj=self)
-
-@mlcd.py_class
-class Add(Expr):
- lhs: Expr
- rhs: Expr
- def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
- lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"])
- rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"])
- return lhs + rhs
-
-@mlcd.py_class
-class Assign(Stmt):
- lhs: Var
- rhs: Expr
- def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
- rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"])
- printer.var_def(obj=self.lhs, frame=printer.frames[-1], name=self.lhs.name)
- lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"])
- return mlt.Assign(lhs=lhs, rhs=rhs)
-
-@mlcd.py_class
-class Func(mlcd.PyClass):
- name: str
- args: list[Var]
- stmts: list[Stmt]
- ret: Var
- def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
- with printer.with_frame(mlcp.DefaultFrame()):
- for arg in self.args:
- printer.var_def(obj=arg, frame=printer.frames[-1], name=arg.name)
- args: list[mlt.Expr] = [printer(obj=arg, path=path["args"][i]) for i, arg in enumerate(self.args)]
- stmts: list[mlt.Expr] = [printer(obj=stmt, path=path["stmts"][i]) for i, stmt in enumerate(self.stmts)]
- ret_stmt = mlt.Return(printer(obj=self.ret, path=path["ret"]))
- return mlt.Function(
- name=mlt.Id(self.name),
- args=[mlt.Assign(lhs=arg, rhs=None) for arg in args],
- decorators=[],
- return_type=None,
- body=[*stmts, ret_stmt],
- )
-
-# An example IR:
-a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e")
-f = Func(
- name="f",
- args=[a, b, c],
+>>> a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e")
+>>> f = Func("f", [a, b, c],
stmts=[
Assign(lhs=d, rhs=Add(a, b)), # d = a + b
Assign(lhs=e, rhs=Add(d, c)), # e = d + c
],
- ret=e,
-)
+ ret=e)
```
-
-
-Two printer APIs are provided for Python-based text format:
-- `mlc.printer.to_python` that converts an IR fragment to Python text, and
-- `mlc.printer.print_python` that further renders the text with proper syntax highlighting.
+- Method `mlc.printer.to_python` converts an IR node to Python-based text;
```python
>>> print(mlcp.to_python(f)) # Stringify to Python
@@ -209,12 +143,25 @@ def f(a, b, c):
d = a + b
e = d + c
return e
+```
+
+- Method `mlc.printer.print_python` further renders the text with proper syntax highlighting. [[Screenshot](https://raw.githubusercontent.com/gist/potatomashed/5a9b20edbdde1b9a91a360baa6bce9ff/raw/3c68031eaba0620a93add270f8ad7ed2c8724a78/mlc-python-printer.svg)]
+
+```python
>>> mlcp.print_python(f) # Syntax highlighting
```
+**AST Parser.** MLC has a concise set of APIs for implementing a parser with Python's AST module, including:
+- Inspection API that obtains source code of a Python class or function and the variables they capture;
+- Variable management APIs that help with proper scoping;
+- AST fragment evaluation APIs;
+- Error rendering APIs.
+
+**[[Example](https://github.com/mlc-ai/mlc-python/blob/main/python/mlc/testing/toy_ir/parser.py)]**. With MLC APIs, a parser for the Python text format defined above by `__ir_print__` can be implemented in 100 lines of code.
+
### :zap: Zero-Copy Interoperability with C++ Plugins
-TBD
+🚧 Under construction.
## :fuelpump: Development
diff --git a/include/mlc/core/typing.h b/include/mlc/core/typing.h
index 0021b792..1d8a99da 100644
--- a/include/mlc/core/typing.h
+++ b/include/mlc/core/typing.h
@@ -122,8 +122,7 @@ struct AtomicType : public Type {
};
struct PtrTypeObj : protected MLCTypingPtr {
- explicit PtrTypeObj(Type ty) : MLCTypingPtr{} { this->TyMutable() = ty; }
- Type Ty() const { return Type(reinterpret_cast &>(this->MLCTypingPtr::ty)); }
+ explicit PtrTypeObj(Type ty) : MLCTypingPtr{} { this->TyMut() = ty; }
::mlc::Str __str__() const {
std::ostringstream os;
os << "Ptr[" << this->Ty() << "]";
@@ -138,20 +137,20 @@ struct PtrTypeObj : protected MLCTypingPtr {
MLC_DEF_STATIC_TYPE(PtrTypeObj, TypeObj, MLCTypeIndex::kMLCTypingPtr, "mlc.core.typing.PtrType");
private:
- Type &TyMutable() { return reinterpret_cast(this->MLCTypingPtr::ty); }
+ Type &TyMut() { return reinterpret_cast(this->MLCTypingPtr::ty); }
+ Type Ty() const { return Type(reinterpret_cast &>(this->MLCTypingPtr::ty)); }
};
struct PtrType : public Type {
MLC_DEF_OBJ_REF(PtrType, PtrTypeObj, Type)
.StaticFn("__init__", InitOf)
- .MemFn("_ty", &PtrTypeObj::Ty)
+ ._Field("ty", offsetof(MLCTypingPtr, ty), sizeof(MLCTypingPtr::ty), false, ParseType())
.MemFn("__str__", &PtrTypeObj::__str__)
.MemFn("__cxx_str__", &PtrTypeObj::__cxx_str__);
};
struct OptionalObj : protected MLCTypingOptional {
explicit OptionalObj(Type ty) : MLCTypingOptional{} { this->TyMutable() = ty; }
- Type Ty() const { return Type(reinterpret_cast &>(this->MLCTypingOptional::ty)); }
::mlc::Str __str__() const {
std::ostringstream os;
os << this->Ty() << " | None";
@@ -167,19 +166,19 @@ struct OptionalObj : protected MLCTypingOptional {
private:
Type &TyMutable() { return reinterpret_cast(this->MLCTypingOptional::ty); }
+ Type Ty() const { return Type(reinterpret_cast &>(this->MLCTypingOptional::ty)); }
};
struct Optional : public Type {
MLC_DEF_OBJ_REF(Optional, OptionalObj, Type)
.StaticFn("__init__", InitOf)
- .MemFn("_ty", &OptionalObj::Ty)
+ ._Field("ty", offsetof(MLCTypingOptional, ty), sizeof(MLCTypingOptional::ty), false, ParseType())
.MemFn("__str__", &OptionalObj::__str__)
.MemFn("__cxx_str__", &OptionalObj::__cxx_str__);
};
struct ListObj : protected MLCTypingList {
explicit ListObj(Type ty) : MLCTypingList{} { this->TyMutable() = ty; }
- Type Ty() const { return Type(reinterpret_cast &>(this->MLCTypingList::ty)); }
::mlc::Str __str__() const {
std::ostringstream os;
os << "list[" << this->Ty() << "]";
@@ -195,31 +194,30 @@ struct ListObj : protected MLCTypingList {
protected:
Type &TyMutable() { return reinterpret_cast(this->MLCTypingList::ty); }
+ Type Ty() const { return Type(reinterpret_cast &>(this->MLCTypingList::ty)); }
};
struct List : public Type {
MLC_DEF_OBJ_REF(List, ListObj, Type)
.StaticFn("__init__", InitOf)
- .MemFn("_ty", &ListObj::Ty)
+ ._Field("ty", offsetof(MLCTypingList, ty), sizeof(MLCTypingList::ty), false, ParseType())
.MemFn("__str__", &ListObj::__str__)
.MemFn("__cxx_str__", &ListObj::__cxx_str__);
};
struct DictObj : protected MLCTypingDict {
explicit DictObj(Type ty_k, Type ty_v) : MLCTypingDict{} {
- this->TyMutableK() = ty_k;
- this->TyMutableV() = ty_v;
+ this->TyKMut() = ty_k;
+ this->TyVMut() = ty_v;
}
- Type key() const { return Type(reinterpret_cast &>(this->ty_k)); }
- Type value() const { return Type(reinterpret_cast &>(this->ty_v)); }
::mlc::Str __str__() const {
std::ostringstream os;
- os << "dict[" << this->key() << ", " << this->value() << "]";
+ os << "dict[" << this->TyK() << ", " << this->TyV() << "]";
return os.str();
}
::mlc::Str __cxx_str__() const {
- ::mlc::Str k_str = ::mlc::base::LibState::CxxStr(this->key());
- ::mlc::Str v_str = ::mlc::base::LibState::CxxStr(this->value());
+ ::mlc::Str k_str = ::mlc::base::LibState::CxxStr(this->TyK());
+ ::mlc::Str v_str = ::mlc::base::LibState::CxxStr(this->TyV());
std::ostringstream os;
os << "::mlc::Dict<" << k_str->data() << ", " << v_str->data() << ">";
return os.str();
@@ -227,15 +225,17 @@ struct DictObj : protected MLCTypingDict {
MLC_DEF_STATIC_TYPE(DictObj, TypeObj, MLCTypeIndex::kMLCTypingDict, "mlc.core.typing.Dict");
protected:
- Type &TyMutableK() { return reinterpret_cast(this->ty_k); }
- Type &TyMutableV() { return reinterpret_cast(this->ty_v); }
+ Type &TyKMut() { return reinterpret_cast(this->ty_k); }
+ Type &TyVMut() { return reinterpret_cast(this->ty_v); }
+ Type TyK() const { return Type(reinterpret_cast &>(this->ty_k)); }
+ Type TyV() const { return Type(reinterpret_cast &>(this->ty_v)); }
};
struct Dict : public Type {
MLC_DEF_OBJ_REF(Dict, DictObj, Type)
.StaticFn("__init__", InitOf)
- .MemFn("_key", &DictObj::key)
- .MemFn("_value", &DictObj::value)
+ ._Field("ty_k", offsetof(MLCTypingDict, ty_k), sizeof(MLCTypingDict::ty_k), false, ParseType())
+ ._Field("ty_v", offsetof(MLCTypingDict, ty_v), sizeof(MLCTypingDict::ty_v), false, ParseType())
.MemFn("__str__", &DictObj::__str__)
.MemFn("__cxx_str__", &DictObj::__cxx_str__);
};
diff --git a/include/mlc/core/utils.h b/include/mlc/core/utils.h
index 3bf2815e..67f86eab 100644
--- a/include/mlc/core/utils.h
+++ b/include/mlc/core/utils.h
@@ -128,6 +128,13 @@ struct ReflectionHelper {
return *this;
}
+ inline ReflectionHelper &_Field(const char *name, int64_t field_offset, int32_t num_bytes, bool frozen, Any ty) {
+ this->any_pool.push_back(ty);
+ int32_t index = static_cast(this->fields.size());
+ this->fields.emplace_back(MLCTypeField{name, index, field_offset, num_bytes, frozen, ty.v.v_obj});
+ return *this;
+ }
+
template inline ReflectionHelper &MemFn(const char *name, Callable &&method) {
MLCTypeMethod m = this->PrepareMethod(name, std::forward(method));
m.kind = kMemFn;
diff --git a/include/mlc/printer/ir_printer.h b/include/mlc/printer/ir_printer.h
index c4cdf52e..83213948 100644
--- a/include/mlc/printer/ir_printer.h
+++ b/include/mlc/printer/ir_printer.h
@@ -49,7 +49,7 @@ struct IRPrinterObj : public Object {
bool VarIsDefined(const ObjectRef &obj) { return obj2info->count(obj) > 0; }
- Id VarDef(const ObjectRef &obj, const ObjectRef &frame, Str name_hint) {
+ Id VarDef(Str name_hint, const ObjectRef &obj, const Optional &frame) {
if (auto it = obj2info.find(obj); it != obj2info.end()) {
Optional name = (*it).second->name;
return Id(name.value());
@@ -66,18 +66,19 @@ struct IRPrinterObj : public Object {
name = name_hint.ToStdString() + '_' + std::to_string(i);
}
defined_names->Set(name, 1);
- this->_VarDef(obj, frame, VarInfo(name, Func([name]() { return Id(name); })));
+ this->_VarDef(VarInfo(name, Func([name]() { return Id(name); })), obj, frame);
return Id(name);
}
- void VarDefNoName(const ObjectRef &obj, const ObjectRef &frame, const Func &creator) {
+ void VarDefNoName(const Func &creator, const ObjectRef &obj, const Optional &frame) {
if (obj2info.count(obj) > 0) {
MLC_THROW(KeyError) << "Variable already defined: " << obj;
}
- this->_VarDef(obj, frame, VarInfo(mlc::Null, creator));
+ this->_VarDef(VarInfo(mlc::Null, creator), obj, frame);
}
- void _VarDef(const ObjectRef &obj, const ObjectRef &frame, VarInfo var_info) {
+ void _VarDef(VarInfo var_info, const ObjectRef &obj, const Optional &_frame) {
+ ObjectRef frame = _frame.defined() ? _frame.value() : this->frames.back().operator ObjectRef();
obj2info->Set(obj, var_info);
auto it = frame_vars.find(frame);
if (it == frame_vars.end()) {
@@ -99,7 +100,7 @@ struct IRPrinterObj : public Object {
obj2info.erase(it);
}
- Optional VarGet(const ObjectRef &obj) {
+ Optional VarGet(const ObjectRef &obj) {
auto it = obj2info.find(obj);
if (it == obj2info.end()) {
return Null;
diff --git a/pyproject.toml b/pyproject.toml
index 9dd02b0b..02cabca4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,10 +1,11 @@
[project]
name = "mlc-python"
-version = "0.0.13"
+version = "0.0.14"
dependencies = [
'numpy >= 1.22',
- "ml-dtypes >= 0.1",
+ 'ml-dtypes >= 0.1',
'Pygments>=2.4.0',
+ 'colorama',
'setuptools ; platform_system == "Windows"',
]
description = ""
diff --git a/python/mlc/__init__.py b/python/mlc/__init__.py
index 148dd6a7..8c060f2c 100644
--- a/python/mlc/__init__.py
+++ b/python/mlc/__init__.py
@@ -1,4 +1,4 @@
-from . import _cython, ast, cc, dataclasses, printer
+from . import _cython, cc, dataclasses, parser, printer
from ._cython import Ptr, Str
from .core import DataType, Device, Dict, Error, Func, List, Object, ObjectPath, typing
from .dataclasses import PyClass, c_class, py_class
diff --git a/python/mlc/_cython/core.pyx b/python/mlc/_cython/core.pyx
index 0a1fdd52..93dd60fc 100644
--- a/python/mlc/_cython/core.pyx
+++ b/python/mlc/_cython/core.pyx
@@ -1360,12 +1360,16 @@ def make_mlc_init(list fields):
cdef tuple setters = _setters
cdef int32_t num_args = len(args)
cdef int32_t i = 0
+ cdef object e = None
assert num_args == len(setters)
while i < num_args:
try:
setters[i](self, args[i])
- except Exception as e: # no-cython-lint
- raise ValueError(f"Failed to set field `{fields[i].name}`: {str(e)}. Got: {args[i]}")
+ except Exception as _e: # no-cython-lint
+ e = ValueError(f"Failed to set field `{fields[i].name}`: {str(_e)}. Got: {args[i]}")
+ e = e.with_traceback(_e.__traceback__)
+ if e is not None:
+ raise e
i += 1
return _mlc_init
diff --git a/python/mlc/ast/__init__.py b/python/mlc/ast/__init__.py
deleted file mode 100644
index 402b60b4..00000000
--- a/python/mlc/ast/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from . import translate
-from .inspection import InspectResult, inspect_program
-from .mlc_ast import * # noqa: F403
-from .translate import translate_mlc_to_py, translate_py_to_mlc
diff --git a/python/mlc/ast/mlc_ast.py b/python/mlc/ast/mlc_ast.py
deleted file mode 100644
index 72bb7861..00000000
--- a/python/mlc/ast/mlc_ast.py
+++ /dev/null
@@ -1,855 +0,0 @@
-# Derived from
-# https://github.com/python/typeshed/blob/main/stdlib/ast.pyi
-
-from typing import Any, Optional
-
-from mlc import dataclasses as mlcd
-
-_Identifier = str
-
-
-@mlcd.py_class(type_key="mlc.ast.AST", structure="bind")
-class AST(mlcd.PyClass): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.mod", structure="bind")
-class mod(AST): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.expr_context", structure="bind")
-class expr_context(AST): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.operator", structure="bind")
-class operator(AST): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.cmpop", structure="bind")
-class cmpop(AST): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.unaryop", structure="bind")
-class unaryop(AST): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.boolop", structure="bind")
-class boolop(AST): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.type_ignore", structure="bind")
-class type_ignore(AST): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.stmt", structure="bind")
-class stmt(AST):
- lineno: Optional[int]
- col_offset: Optional[int]
- end_lineno: Optional[int]
- end_col_offset: Optional[int]
-
-
-@mlcd.py_class(type_key="mlc.ast.expr", structure="bind")
-class expr(AST):
- lineno: Optional[int]
- col_offset: Optional[int]
- end_lineno: Optional[int]
- end_col_offset: Optional[int]
-
-
-@mlcd.py_class(type_key="mlc.ast.type_param", structure="bind")
-class type_param(AST):
- lineno: Optional[int]
- col_offset: Optional[int]
- end_lineno: Optional[int]
- end_col_offset: Optional[int]
-
-
-@mlcd.py_class(type_key="mlc.ast.pattern", structure="bind")
-class pattern(AST):
- lineno: Optional[int]
- col_offset: Optional[int]
- end_lineno: Optional[int]
- end_col_offset: Optional[int]
-
-
-@mlcd.py_class(type_key="mlc.ast.arg", structure="bind")
-class arg(AST):
- lineno: Optional[int]
- col_offset: Optional[int]
- end_lineno: Optional[int]
- end_col_offset: Optional[int]
- arg: _Identifier
- annotation: Optional[expr]
- type_comment: Optional[str]
-
-
-@mlcd.py_class(type_key="mlc.ast.keyword", structure="bind")
-class keyword(AST):
- lineno: Optional[int]
- col_offset: Optional[int]
- end_lineno: Optional[int]
- end_col_offset: Optional[int]
- arg: Optional[_Identifier]
- value: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.alias", structure="bind")
-class alias(AST):
- lineno: Optional[int]
- col_offset: Optional[int]
- end_lineno: Optional[int]
- end_col_offset: Optional[int]
- name: str
- asname: Optional[_Identifier]
-
-
-@mlcd.py_class(type_key="mlc.ast.arguments", structure="bind")
-class arguments(AST):
- posonlyargs: list[arg]
- args: list[arg]
- vararg: Optional[arg]
- kwonlyargs: list[arg]
- kw_defaults: list[Optional[expr]]
- kwarg: Optional[arg]
- defaults: list[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.comprehension", structure="bind")
-class comprehension(AST):
- target: expr
- iter: expr
- ifs: list[expr]
- is_async: int
-
-
-@mlcd.py_class(type_key="mlc.ast.withitem", structure="bind")
-class withitem(AST):
- context_expr: expr
- optional_vars: Optional[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.TypeIgnore", structure="bind")
-class TypeIgnore(type_ignore):
- lineno: Optional[int]
- tag: str
-
-
-@mlcd.py_class(type_key="mlc.ast.Module", structure="bind")
-class Module(mod):
- body: list[stmt]
- type_ignores: list[TypeIgnore]
-
-
-@mlcd.py_class(type_key="mlc.ast.Interactive", structure="bind")
-class Interactive(mod):
- body: list[stmt]
-
-
-@mlcd.py_class(type_key="mlc.ast.Expression", structure="bind")
-class Expression(mod):
- body: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.FunctionType", structure="bind")
-class FunctionType(mod):
- argtypes: list[expr]
- returns: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.FunctionDef", structure="bind")
-class FunctionDef(stmt):
- name: _Identifier
- args: arguments
- body: list[stmt]
- decorator_list: list[expr]
- returns: Optional[expr]
- type_comment: Optional[str]
- type_params: Optional[list[type_param]]
-
-
-@mlcd.py_class(type_key="mlc.ast.AsyncFunctionDef", structure="bind")
-class AsyncFunctionDef(stmt):
- name: _Identifier
- args: arguments
- body: list[stmt]
- decorator_list: list[expr]
- returns: Optional[expr]
- type_comment: Optional[str]
- type_params: Optional[list[type_param]]
-
-
-@mlcd.py_class(type_key="mlc.ast.ClassDef", structure="bind")
-class ClassDef(stmt):
- name: _Identifier
- bases: list[expr]
- keywords: list[keyword]
- body: list[stmt]
- decorator_list: list[expr]
- type_params: Optional[list[type_param]]
-
-
-@mlcd.py_class(type_key="mlc.ast.Return", structure="bind")
-class Return(stmt):
- value: Optional[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.Delete", structure="bind")
-class Delete(stmt):
- targets: list[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.Assign", structure="bind")
-class Assign(stmt):
- targets: list[expr]
- value: expr
- type_comment: Optional[str]
-
-
-@mlcd.py_class(type_key="mlc.ast.Attribute", structure="bind")
-class Attribute(expr):
- value: expr
- attr: _Identifier
- ctx: expr_context
-
-
-@mlcd.py_class(type_key="mlc.ast.Subscript", structure="bind")
-class Subscript(expr):
- value: expr
- slice: expr
- ctx: expr_context
-
-
-@mlcd.py_class(type_key="mlc.ast.AugAssign", structure="bind")
-class AugAssign(stmt):
- target: Any # Name | Attribute | Subscript
- op: operator
- value: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.AnnAssign", structure="bind")
-class AnnAssign(stmt):
- target: Any # Name | Attribute | Subscript
- annotation: expr
- value: Optional[expr]
- simple: int
-
-
-@mlcd.py_class(type_key="mlc.ast.For", structure="bind")
-class For(stmt):
- target: expr
- iter: expr
- body: list[stmt]
- orelse: list[stmt]
- type_comment: Optional[str]
-
-
-@mlcd.py_class(type_key="mlc.ast.AsyncFor", structure="bind")
-class AsyncFor(stmt):
- target: expr
- iter: expr
- body: list[stmt]
- orelse: list[stmt]
- type_comment: Optional[str]
-
-
-@mlcd.py_class(type_key="mlc.ast.While", structure="bind")
-class While(stmt):
- test: expr
- body: list[stmt]
- orelse: list[stmt]
-
-
-@mlcd.py_class(type_key="mlc.ast.If", structure="bind")
-class If(stmt):
- test: expr
- body: list[stmt]
- orelse: list[stmt]
-
-
-@mlcd.py_class(type_key="mlc.ast.With", structure="bind")
-class With(stmt):
- items: list[withitem]
- body: list[stmt]
- type_comment: Optional[str]
-
-
-@mlcd.py_class(type_key="mlc.ast.AsyncWith", structure="bind")
-class AsyncWith(stmt):
- items: list[withitem]
- body: list[stmt]
- type_comment: Optional[str]
-
-
-@mlcd.py_class(type_key="mlc.ast.match_case", structure="bind")
-class match_case(AST):
- pattern: pattern
- guard: Optional[expr]
- body: list[stmt]
-
-
-@mlcd.py_class(type_key="mlc.ast.Match", structure="bind")
-class Match(stmt):
- subject: expr
- cases: list[match_case]
-
-
-@mlcd.py_class(type_key="mlc.ast.Raise", structure="bind")
-class Raise(stmt):
- exc: Optional[expr]
- cause: Optional[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.ExceptHandler", structure="bind")
-class ExceptHandler(AST):
- lineno: Optional[int]
- col_offset: Optional[int]
- end_lineno: Optional[int]
- end_col_offset: Optional[int]
- type: Optional[expr]
- name: Optional[_Identifier]
- body: list[stmt]
-
-
-@mlcd.py_class(type_key="mlc.ast.Try", structure="bind")
-class Try(stmt):
- body: list[stmt]
- handlers: list[ExceptHandler]
- orelse: list[stmt]
- finalbody: list[stmt]
-
-
-@mlcd.py_class(type_key="mlc.ast.TryStar", structure="bind")
-class TryStar(stmt):
- body: list[stmt]
- handlers: list[ExceptHandler]
- orelse: list[stmt]
- finalbody: list[stmt]
-
-
-@mlcd.py_class(type_key="mlc.ast.Assert", structure="bind")
-class Assert(stmt):
- test: expr
- msg: Optional[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.Import", structure="bind")
-class Import(stmt):
- names: list[alias]
-
-
-@mlcd.py_class(type_key="mlc.ast.ImportFrom", structure="bind")
-class ImportFrom(stmt):
- module: Optional[str]
- names: list[alias]
- level: int
-
-
-@mlcd.py_class(type_key="mlc.ast.Global", structure="bind")
-class Global(stmt):
- names: list[_Identifier]
-
-
-@mlcd.py_class(type_key="mlc.ast.Nonlocal", structure="bind")
-class Nonlocal(stmt):
- names: list[_Identifier]
-
-
-@mlcd.py_class(type_key="mlc.ast.Expr", structure="bind")
-class Expr(stmt):
- value: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.Pass", structure="bind")
-class Pass(stmt): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Break", structure="bind")
-class Break(stmt): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Continue", structure="bind")
-class Continue(stmt): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.BoolOp", structure="bind")
-class BoolOp(expr):
- op: boolop
- values: list[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.Name", structure="bind")
-class Name(expr):
- id: _Identifier
- ctx: expr_context
-
-
-@mlcd.py_class(type_key="mlc.ast.NamedExpr", structure="bind")
-class NamedExpr(expr):
- target: Name
- value: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.BinOp", structure="bind")
-class BinOp(expr):
- left: expr
- op: operator
- right: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.UnaryOp", structure="bind")
-class UnaryOp(expr):
- op: unaryop
- operand: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.Lambda", structure="bind")
-class Lambda(expr):
- args: arguments
- body: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.IfExp", structure="bind")
-class IfExp(expr):
- test: expr
- body: expr
- orelse: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.Dict", structure="bind")
-class Dict(expr):
- keys: list[Optional[expr]]
- values: list[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.Set", structure="bind")
-class Set(expr):
- elts: list[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.ListComp", structure="bind")
-class ListComp(expr):
- elt: expr
- generators: list[comprehension]
-
-
-@mlcd.py_class(type_key="mlc.ast.SetComp", structure="bind")
-class SetComp(expr):
- elt: expr
- generators: list[comprehension]
-
-
-@mlcd.py_class(type_key="mlc.ast.DictComp", structure="bind")
-class DictComp(expr):
- key: expr
- value: expr
- generators: list[comprehension]
-
-
-@mlcd.py_class(type_key="mlc.ast.GeneratorExp", structure="bind")
-class GeneratorExp(expr):
- elt: expr
- generators: list[comprehension]
-
-
-@mlcd.py_class(type_key="mlc.ast.Await", structure="bind")
-class Await(expr):
- value: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.Yield", structure="bind")
-class Yield(expr):
- value: Optional[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.YieldFrom", structure="bind")
-class YieldFrom(expr):
- value: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.Compare", structure="bind")
-class Compare(expr):
- left: expr
- ops: list[cmpop]
- comparators: list[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.Call", structure="bind")
-class Call(expr):
- func: expr
- args: list[expr]
- keywords: list[keyword]
-
-
-@mlcd.py_class(type_key="mlc.ast.FormattedValue", structure="bind")
-class FormattedValue(expr):
- value: expr
- conversion: int
- format_spec: Optional[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.JoinedStr", structure="bind")
-class JoinedStr(expr):
- values: list[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.Ellipsis", structure="bind")
-class Ellipsis(mlcd.PyClass): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Constant", structure="bind")
-class Constant(expr):
- value: Any # None, str, bytes, bool, int, float, complex, Ellipsis
- kind: Optional[str]
-
-
-@mlcd.py_class(type_key="mlc.ast.Starred", structure="bind")
-class Starred(expr):
- value: expr
- ctx: expr_context
-
-
-@mlcd.py_class(type_key="mlc.ast.List", structure="bind")
-class List(expr):
- elts: list[expr]
- ctx: expr_context
-
-
-@mlcd.py_class(type_key="mlc.ast.Tuple", structure="bind")
-class Tuple(expr):
- elts: list[expr]
- ctx: expr_context
- dims: list[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.Slice", structure="bind")
-class Slice(expr):
- lower: Optional[expr]
- upper: Optional[expr]
- step: Optional[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.Load", structure="bind")
-class Load(expr_context): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Store", structure="bind")
-class Store(expr_context): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Del", structure="bind")
-class Del(expr_context): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.And", structure="bind")
-class And(boolop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Or", structure="bind")
-class Or(boolop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Add", structure="bind")
-class Add(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Sub", structure="bind")
-class Sub(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Mult", structure="bind")
-class Mult(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.MatMult", structure="bind")
-class MatMult(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Div", structure="bind")
-class Div(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Mod", structure="bind")
-class Mod(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Pow", structure="bind")
-class Pow(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.LShift", structure="bind")
-class LShift(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.RShift", structure="bind")
-class RShift(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.BitOr", structure="bind")
-class BitOr(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.BitXor", structure="bind")
-class BitXor(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.BitAnd", structure="bind")
-class BitAnd(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.FloorDiv", structure="bind")
-class FloorDiv(operator): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Invert", structure="bind")
-class Invert(unaryop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Not", structure="bind")
-class Not(unaryop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.UAdd", structure="bind")
-class UAdd(unaryop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.USub", structure="bind")
-class USub(unaryop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Eq", structure="bind")
-class Eq(cmpop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.NotEq", structure="bind")
-class NotEq(cmpop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Lt", structure="bind")
-class Lt(cmpop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.LtE", structure="bind")
-class LtE(cmpop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Gt", structure="bind")
-class Gt(cmpop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.GtE", structure="bind")
-class GtE(cmpop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.Is", structure="bind")
-class Is(cmpop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.IsNot", structure="bind")
-class IsNot(cmpop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.In", structure="bind")
-class In(cmpop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.NotIn", structure="bind")
-class NotIn(cmpop): ...
-
-
-@mlcd.py_class(type_key="mlc.ast.MatchValue", structure="bind")
-class MatchValue(pattern):
- value: expr
-
-
-@mlcd.py_class(type_key="mlc.ast.MatchSingleton", structure="bind")
-class MatchSingleton(pattern):
- value: int # boolean
-
-
-@mlcd.py_class(type_key="mlc.ast.MatchSequence", structure="bind")
-class MatchSequence(pattern):
- patterns: list[pattern]
-
-
-@mlcd.py_class(type_key="mlc.ast.MatchMapping", structure="bind")
-class MatchMapping(pattern):
- keys: list[expr]
- patterns: list[pattern]
- rest: Optional[_Identifier]
-
-
-@mlcd.py_class(type_key="mlc.ast.MatchClass", structure="bind")
-class MatchClass(pattern):
- cls: expr
- patterns: list[pattern]
- kwd_attrs: list[_Identifier]
- kwd_patterns: list[pattern]
-
-
-@mlcd.py_class(type_key="mlc.ast.MatchStar", structure="bind")
-class MatchStar(pattern):
- name: Optional[_Identifier]
-
-
-@mlcd.py_class(type_key="mlc.ast.MatchAs", structure="bind")
-class MatchAs(pattern):
- pattern: Optional[pattern]
- name: Optional[_Identifier]
-
-
-@mlcd.py_class(type_key="mlc.ast.MatchOr", structure="bind")
-class MatchOr(pattern):
- patterns: list[pattern]
-
-
-@mlcd.py_class(type_key="mlc.ast.TypeVar", structure="bind")
-class TypeVar(type_param):
- name: _Identifier
- bound: Optional[expr]
- default_value: Optional[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.ParamSpec", structure="bind")
-class ParamSpec(type_param):
- name: _Identifier
- default_value: Optional[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.TypeVarTuple", structure="bind")
-class TypeVarTuple(type_param):
- name: _Identifier
- default_value: Optional[expr]
-
-
-@mlcd.py_class(type_key="mlc.ast.TypeAlias", structure="bind")
-class TypeAlias(stmt):
- name: Name
- type_params: Optional[list[type_param]]
- value: expr
-
-
-__all__ = [
- "AST",
- "mod",
- "expr_context",
- "operator",
- "cmpop",
- "unaryop",
- "boolop",
- "type_ignore",
- "stmt",
- "expr",
- "type_param",
- "pattern",
- "arg",
- "keyword",
- "alias",
- "arguments",
- "comprehension",
- "withitem",
- "TypeIgnore",
- "Module",
- "Interactive",
- "Expression",
- "FunctionType",
- "FunctionDef",
- "AsyncFunctionDef",
- "ClassDef",
- "Return",
- "Delete",
- "Assign",
- "Attribute",
- "Subscript",
- "AugAssign",
- "AnnAssign",
- "For",
- "AsyncFor",
- "While",
- "If",
- "With",
- "AsyncWith",
- "match_case",
- "Match",
- "Raise",
- "ExceptHandler",
- "Try",
- "TryStar",
- "Assert",
- "Import",
- "ImportFrom",
- "Global",
- "Nonlocal",
- "Expr",
- "Pass",
- "Break",
- "Continue",
- "BoolOp",
- "Name",
- "NamedExpr",
- "BinOp",
- "UnaryOp",
- "Lambda",
- "IfExp",
- "Dict",
- "Set",
- "ListComp",
- "SetComp",
- "DictComp",
- "GeneratorExp",
- "Await",
- "Yield",
- "YieldFrom",
- "Compare",
- "Call",
- "FormattedValue",
- "JoinedStr",
- "Constant",
- "Starred",
- "List",
- "Tuple",
- "Slice",
- "Load",
- "Store",
- "Del",
- "And",
- "Or",
- "Add",
- "Sub",
- "Mult",
- "MatMult",
- "Div",
- "Mod",
- "Pow",
- "LShift",
- "RShift",
- "BitOr",
- "BitXor",
- "BitAnd",
- "FloorDiv",
- "Invert",
- "Not",
- "UAdd",
- "USub",
- "Eq",
- "NotEq",
- "Lt",
- "LtE",
- "Gt",
- "GtE",
- "Is",
- "IsNot",
- "In",
- "NotIn",
- "MatchValue",
- "MatchSingleton",
- "MatchSequence",
- "MatchMapping",
- "MatchClass",
- "MatchStar",
- "MatchAs",
- "MatchOr",
- "TypeVar",
- "ParamSpec",
- "TypeVarTuple",
- "TypeAlias",
-]
diff --git a/python/mlc/ast/translate.py b/python/mlc/ast/translate.py
deleted file mode 100644
index cf4010ce..00000000
--- a/python/mlc/ast/translate.py
+++ /dev/null
@@ -1,109 +0,0 @@
-from __future__ import annotations
-
-import ast
-import typing
-
-from mlc.core import List
-
-from . import mlc_ast
-
-TYPE_PY_TO_MLC_TRANSLATOR = typing.Callable[[ast.AST], mlc_ast.AST]
-TYPE_MLC_TO_PY_TRANSLATOR = typing.Callable[[mlc_ast.AST], ast.AST]
-
-
-def translate_py_to_mlc(py_node: ast.AST) -> mlc_ast.AST:
- assert isinstance(py_node, ast.AST), f"Expected AST node, got {type(py_node)}"
- if (translator := _PY_TO_MLC_VTABLE.get(type(py_node))) is not None:
- return translator(py_node)
- raise NotImplementedError(f"Translation not implemented: {type(py_node)}")
-
-
-def translate_mlc_to_py(mlc_node: mlc_ast.AST) -> ast.AST:
- assert isinstance(mlc_node, mlc_ast.AST), f"Expected AST node, got {type(mlc_node)}"
- if (translator := _MLC_TO_PY_VTABLE.get(type(mlc_node))) is not None:
- return translator(mlc_node)
- raise NotImplementedError(f"Translation not implemented: {type(mlc_node)}")
-
-
-def py_to_mlc_vtable_create() -> dict[type[ast.AST], TYPE_PY_TO_MLC_TRANSLATOR]:
- def _translate_field(value: typing.Any) -> typing.Any:
- if isinstance(value, ast.AST):
- return translate_py_to_mlc(value)
- elif value is ...:
- return mlc_ast.Ellipsis()
- elif isinstance(value, list):
- return [_translate_field(item) for item in value]
- return value
-
- def _create_entry(
- mlc_cls: type[mlc_ast.AST], py_cls: type[ast.AST]
- ) -> TYPE_PY_TO_MLC_TRANSLATOR:
- def py_to_mlc_default(py_node: ast.AST) -> mlc_ast.AST:
- return mlc_cls(
- **{
- field: _translate_field(getattr(py_node, field, None))
- for field in mlc_field_names
- }
- )
-
- mlc_field_names = list(typing.get_type_hints(mlc_cls).keys())
- return py_to_mlc_default
-
- def _create_entry_constant() -> TYPE_PY_TO_MLC_TRANSLATOR:
- # Special handling to convert `bytea` to `str`, because `bytes` is not supported in MLC
- entry = _create_entry(mlc_ast.Constant, ast.Constant)
-
- def py_to_mlc_constant(py_node: ast.Constant) -> mlc_ast.Constant:
- if isinstance(py_node.value, bytes):
- py_node.value = py_node.value.decode("utf-8")
- return entry(py_node)
-
- return typing.cast(TYPE_PY_TO_MLC_TRANSLATOR, py_to_mlc_constant)
-
- vtable: dict[type[ast.AST], TYPE_PY_TO_MLC_TRANSLATOR] = {
- type(...): lambda _: mlc_ast.Ellipsis(), # type: ignore
- ast.Constant: _create_entry_constant(),
- }
- for cls_name in mlc_ast.__all__:
- mlc_cls = getattr(mlc_ast, cls_name)
- if py_cls := getattr(ast, cls_name, None):
- if py_cls not in vtable:
- vtable[py_cls] = _create_entry(mlc_cls, py_cls)
- return vtable
-
-
-def mlc_to_py_vtable_create() -> dict[type[mlc_ast.AST], TYPE_MLC_TO_PY_TRANSLATOR]:
- def _translate_field(value: typing.Any) -> typing.Any:
- if isinstance(value, mlc_ast.AST):
- return translate_mlc_to_py(value)
- elif isinstance(value, (list, List)):
- return [_translate_field(item) for item in value]
- return value
-
- def _create_entry(
- py_cls: type[ast.AST], mlc_cls: type[mlc_ast.AST]
- ) -> TYPE_MLC_TO_PY_TRANSLATOR:
- def mlc_to_py_default(mlc_node: mlc_ast.AST) -> ast.AST:
- return py_cls(
- **{
- field: _translate_field(getattr(mlc_node, field, None))
- for field in py_field_names
- }
- )
-
- py_field_names = list(typing.get_type_hints(mlc_cls).keys())
- return mlc_to_py_default
-
- vtable: dict[type[mlc_ast.AST], TYPE_MLC_TO_PY_TRANSLATOR] = {
- mlc_ast.Ellipsis: lambda _: ..., # type: ignore
- }
- for cls_name in mlc_ast.__all__:
- mlc_cls = getattr(mlc_ast, cls_name)
- if py_cls := getattr(ast, cls_name, None):
- if mlc_cls not in vtable:
- vtable[mlc_cls] = _create_entry(py_cls, mlc_cls)
- return vtable
-
-
-_PY_TO_MLC_VTABLE = py_to_mlc_vtable_create()
-_MLC_TO_PY_VTABLE = mlc_to_py_vtable_create()
diff --git a/python/mlc/core/typing.py b/python/mlc/core/typing.py
index 7b38ade8..a3564563 100644
--- a/python/mlc/core/typing.py
+++ b/python/mlc/core/typing.py
@@ -63,9 +63,10 @@ def _ctype(self) -> typing.Any:
@c_class_core("mlc.core.typing.PtrType")
class PtrType(Type):
- @property
- def ty(self) -> Type:
- return self._C(b"_ty", self)
+ ty: Type
+
+ def __init__(self, ty: Type) -> None:
+ self._mlc_init(ty)
def args(self) -> tuple:
return (self.ty,)
@@ -76,13 +77,11 @@ def _ctype(self) -> typing.Any:
@c_class_core("mlc.core.typing.Optional")
class Optional(Type):
+ ty: Type
+
def __init__(self, ty: Type) -> None:
self._mlc_init(ty)
- @property
- def ty(self) -> Type:
- return self._C(b"_ty", self)
-
def args(self) -> tuple:
return (self.ty,)
@@ -92,13 +91,11 @@ def _ctype(self) -> typing.Any:
@c_class_core("mlc.core.typing.List")
class List(Type):
+ ty: Type
+
def __init__(self, ty: Type) -> None:
self._mlc_init(ty)
- @property
- def ty(self) -> Type:
- return self._C(b"_ty", self)
-
def args(self) -> tuple:
return (self.ty,)
@@ -108,19 +105,14 @@ def _ctype(self) -> typing.Any:
@c_class_core("mlc.core.typing.Dict")
class Dict(Type):
- def __init__(self, key_ty: Type, value_ty: Type) -> None:
- self._mlc_init(key_ty, value_ty)
-
- @property
- def key(self) -> Type:
- return self._C(b"_key", self)
+ ty_k: Type
+ ty_v: Type
- @property
- def value(self) -> Type:
- return self._C(b"_value", self)
+ def __init__(self, ty_k: Type, ty_v: Type) -> None:
+ self._mlc_init(ty_k, ty_v)
def args(self) -> tuple:
- return (self.key, self.value)
+ return (self.ty_k, self.ty_v)
def _ctype(self) -> typing.Any:
return MLCObjPtr
diff --git a/python/mlc/dataclasses/utils.py b/python/mlc/dataclasses/utils.py
index 91b94311..1fa3a37d 100644
--- a/python/mlc/dataclasses/utils.py
+++ b/python/mlc/dataclasses/utils.py
@@ -78,13 +78,16 @@ def bind_args(*args: typing.Any, **kwargs: typing.Any) -> inspect.BoundArguments
return bound
def method(self: type, *args: typing.Any, **kwargs: typing.Any) -> None:
+ e = None
try:
args = bind_args(*args, **kwargs).args
args = tuple(arg.fn() if isinstance(arg, DefaultFactory) else arg for arg in args)
args = tuple(args[order] for order in ordering)
self._mlc_init(*args) # type: ignore[attr-defined]
- except Exception as e:
- raise TypeError(f"Error in `{signature_str}`: {e}") # type: ignore[attr-defined]
+ except Exception as _e:
+ e = TypeError(f"Error in `{signature_str}`: {_e}").with_traceback(_e.__traceback__)
+ if e is not None:
+ raise e
try:
post_init = self.__post_init__ # type: ignore[attr-defined]
except AttributeError:
diff --git a/python/mlc/parser/__init__.py b/python/mlc/parser/__init__.py
new file mode 100644
index 00000000..2ecaa846
--- /dev/null
+++ b/python/mlc/parser/__init__.py
@@ -0,0 +1,3 @@
+from .diagnostic import DiagnosticError
+from .env import Env, Span, check_decorator
+from .parser import Frame, Parser
diff --git a/python/mlc/parser/diagnostic.py b/python/mlc/parser/diagnostic.py
new file mode 100644
index 00000000..42742d02
--- /dev/null
+++ b/python/mlc/parser/diagnostic.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+import os
+import sys
+from types import TracebackType
+from typing import Any, Literal
+
+import colorama # type: ignore[import-untyped]
+
+from .env import Span
+
+colorama.init(autoreset=True)
+
+
+class DiagnosticError(Exception):  # raised by `raise_diagnostic_error` after the message was rendered
+ pass
+
+
+def raise_diagnostic_error(  # Render a diagnostic to stderr, then always raise DiagnosticError.
+    source: str,  # text whose lines `span` refers to
+    source_name: str,  # name printed in the `-->` location line
+    span: Span,  # region of `source` to highlight
+    err: Exception | str,  # the underlying exception, or a ready-made message
+) -> None:
+    diag_err = DiagnosticError("Diagnostics were emitted, please check rendered error message.")
+    if isinstance(err, Exception):
+        # Headline: "<Type>: <last non-empty line>"; just "<Type>" when `str(err)` is empty.
+        msg = ": ".join([type(err).__name__, *[i for i in str(err).split("\n") if i][-1:]])
+        diag_err.with_traceback(err.__traceback__)
+    else:
+        msg = str(err)
+    print(
+        _render_at(
+            source=source,
+            source_name=source_name,
+            span=span,
+            message=msg,
+            level="error",
+        ),
+        file=sys.stderr,
+    )
+    if isinstance(err, Exception):
+        raise diag_err from err  # keep the original exception as `__cause__`
+    raise diag_err
+
+
+def _render_at(  # Build the multi-line diagnostic text: headline, location, source lines, `^` markers.
+ source: str,  # full text; split into lines and indexed by the span's 1-based rows
+ source_name: str,  # printed in the `-->` location line
+ span: Span,  # rows are clamped to the available lines; columns index into each line
+ message: str,  # headline text; also the whole result when no row is valid
+ level: Literal["warning", "error", "bug", "note"] = "error",  # NOTE(review): table below also knows "help"
+) -> str:
+ lines = source.splitlines()
+ row_st = max(1, span.row_st)
+ row_ed = min(len(lines), span.row_ed)
+ # If no valid rows, just return the bare message.
+ if row_st > row_ed:
+ return message
+ # Map the "level" to a color and label (similar to rang::fg usage in C++).
+ color, diag_type = {
+ "warning": (colorama.Fore.YELLOW, "warning"),
+ "error": (colorama.Fore.RED, "error"),
+ "bug": (colorama.Fore.BLUE, "bug"),
+ "note": (colorama.Fore.RESET, "note"),
+ "help": (colorama.Fore.RESET, "help"),
+ }.get(level, (colorama.Fore.RED, "error"))  # unknown levels are rendered as "error"
+
+ # Prepare lines of output
+ out_lines = [
+ f"{colorama.Style.BRIGHT}{color}{diag_type}{colorama.Style.RESET_ALL}: {message}",
+ f"{colorama.Fore.BLUE} --> {colorama.Style.RESET_ALL}{source_name}:{row_st}:{span.col_st}",
+ ]
+ left_margin_width = len(str(row_ed))  # width of the widest line number shown
+ for row_idx in range(row_st, row_ed + 1):
+ line_text = lines[row_idx - 1] # zero-based
+ line_label = str(row_idx).rjust(left_margin_width)
+ # Step 1. the actual source line
+ out_lines.append(f"{line_label} | {line_text}")
+ # Step 2. the marker line:
+ marker = [" "] * len(line_text)  # one cell per character; cells become "^" below
+ # For the first line...
+ if row_idx == row_st and row_idx == row_ed:
+ # Case 1. Single-line: highlight col_st..col_ed
+ c_start = max(1, span.col_st)  # NOTE(review): used as a list index, so cell 0 is never marked here
+ c_end = min(len(line_text), span.col_ed)
+ marker[c_start:c_end] = "^" * (c_end - c_start)  # empty replacement when c_end <= c_start
+ elif row_idx == row_st:
+ # Case 2. The first line in a multi-line highlight
+ c_start = max(1, span.col_st)
+ marker[c_start:] = "^" * (len(line_text) - c_start)
+ elif row_idx == row_ed:
+ # Case 3. The last line in a multi-line highlight
+ c_end = min(len(line_text), span.col_ed)
+ marker[:c_end] = "^" * c_end
+ else:
+ # Case 4. A line in the middle of row_st..row_ed => highlight entire line
+ marker = ["^"] * len(line_text)
+ out_lines.append(f"{' ' * (left_margin_width)} | {''.join(marker)}")
+ return "\n".join(out_lines)
+
+
+def excepthook(  # Replacement for `sys.excepthook` that hides DiagnosticError tracebacks by default.
+ exctype: type[BaseException],
+ value: BaseException,
+ traceback: TracebackType | None,
+) -> Any:
+ should_hide_backtrace = os.environ.get("MLC_BACKTRACE", None) is None  # any value, even empty, shows it
+ if exctype is DiagnosticError and should_hide_backtrace:  # exact type match; subclasses fall through
+ print("note: run with `MLC_BACKTRACE=1` environment variable to display a backtrace.")  # goes to stdout
+ return
+ sys_excepthook(exctype, value, traceback)  # everything else: delegate to the saved hook
+
+
+sys_excepthook = sys.excepthook  # remember the hook that was active when this module was imported
+sys.excepthook = excepthook  # import-time side effect: install the filtering hook process-wide
diff --git a/python/mlc/ast/inspection.py b/python/mlc/parser/env.py
similarity index 60%
rename from python/mlc/ast/inspection.py
rename to python/mlc/parser/env.py
index cb1ddb39..f273f182 100644
--- a/python/mlc/ast/inspection.py
+++ b/python/mlc/parser/env.py
@@ -4,14 +4,17 @@
import dataclasses
import inspect
from collections.abc import Callable, Generator
-from typing import Any
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ import ast
PY_GETFILE = inspect.getfile
PY_FINDSOURCE = inspect.findsource
-@dataclasses.dataclass(init=False)
-class InspectResult:
+@dataclasses.dataclass
+class Env:
source_name: str
source_start_line: int
source_start_column: int
@@ -20,65 +23,89 @@ class InspectResult:
captured: dict[str, Any]
annotations: dict[str, dict[str, Any]]
- def is_defined_in_class(
- self,
- frames: list[inspect.FrameInfo], # obtained via `inspect.stack()`
- *,
- frame_offset: int = 2,
- is_decorator: Callable[[str], bool] = lambda line: line.startswith("@"),
- ) -> bool:
- # Step 1. Inspect `frames[frame_offset]`
- try:
- lineno = frames[frame_offset].lineno
- line = frames[frame_offset].code_context[0].strip() # type: ignore
- except:
- return False
- # Step 2. Determine by the line itself
- if is_decorator(line):
- return True
- if not line.startswith("class"):
- return False
- # Step 3. Determine by its decorators
- source_lines = self.source_full.splitlines(keepends=True)
- lineno_offset = 2
- try:
- source_line = source_lines[lineno - lineno_offset]
- except IndexError:
- return False
- return is_decorator(source_line.strip())
-
-
-def inspect_program(program: Callable | type) -> InspectResult:
- ret = InspectResult()
- source = inspect_source(program)
- ret.source_name, ret.source_start_line, ret.source_start_column, ret.source, ret.source_full = (
- source.source_name,
- source.source_start_line,
- source.source_start_column,
- source.source,
- source.source_full,
- )
- if inspect.isfunction(program):
- ret.captured = inspect_capture_function(program)
- ret.annotations = inspect_annotations_function(program)
- elif inspect.isclass(program):
- ret.captured = inspect_capture_class(program)
- ret.annotations = inspect_annotations_class(program)
- else:
- raise TypeError(f"{program!r} is not a function or class")
- return ret
+ @staticmethod
+ def from_class(program: type) -> Env:  # Build an Env for a class
+ return Env(
+ **_inspect_source(program), # type: ignore[arg-type]
+ captured=_inspect_capture_class(program),  # variables captured by the class's functions
+ annotations=_inspect_annotations_class(program),  # annotations collected by the class helper
+ )
+
+ @staticmethod
+ def from_function(program: Callable) -> Env:  # Build an Env for a single function
+ return Env(
+ **_inspect_source(program), # type: ignore[arg-type]
+ captured=_inspect_capture_function(program),  # variables the function captures
+ annotations=_inspect_annotations_function(program),  # {function name: its __annotations__}
+ )
+
+
+@dataclasses.dataclass
+class Span:  # A source region: start/end row and start/end column
+ row_st: int  # first row
+ row_ed: int  # last row
+ col_st: int  # start column
+ col_ed: int  # end column
+
+ @staticmethod
+ def from_ast(node: ast.AST, env: Env | None = None) -> Span:  # Derive a span from a node's location attrs
+ row_st: int = getattr(node, "lineno", None) or 1  # missing/falsy -> 1
+ row_ed: int = getattr(node, "end_lineno", None) or row_st  # missing/falsy -> single row
+ col_st: int = getattr(node, "col_offset", None) or 1  # NOTE(review): `or` also maps col_offset 0 to 1
+ col_ed: int = getattr(node, "end_col_offset", None) or col_st  # missing/falsy -> zero-width
+ if env is not None:  # shift from snippet-relative to positions in the full source
+ row_st += env.source_start_line - 1
+ row_ed += env.source_start_line - 1
+ col_st += env.source_start_column
+ col_ed += env.source_start_column
+ return Span(
+ row_st=row_st,
+ row_ed=row_ed,
+ col_st=col_st,
+ col_ed=col_ed,
+ )
+
+
+def check_decorator(  # Heuristic: was the function `frame_offset` frames up applied as a decorator?
+    frames: list[inspect.FrameInfo],  # obtained via `inspect.stack()`
+    *,
+    source_full: str | None = None,  # full source text; needed only for the `class` fallback
+    frame_offset: int = 2,  # which entry of `frames` is the call site
+    checker: Callable[[str], bool] = lambda line: line.startswith("@"),  # gets the raw line
+) -> bool:
+    # Step 1. Inspect `frames[frame_offset]`
+    try:
+        lineno = frames[frame_offset].lineno
+        line = frames[frame_offset].code_context[0]  # type: ignore
+    except Exception:  # short stack / no code context; KeyboardInterrupt etc. still propagate
+        return False
+    # Step 2. Determine by the line itself
+    if checker(line):
+        return True
+    if not line.startswith("class"):
+        return False
+    # Step 3. Determine by its decorators
+    if source_full is None:
+        return False
+    source_lines = source_full.splitlines(keepends=True)
+    lineno_offset = 2  # `lineno` is 1-based, so this selects the line right above the call site
+    index = lineno - lineno_offset
+    if not 0 <= index < len(source_lines):
+        # A negative index would silently wrap around to the end of the file.
+        return False
+    return checker(source_lines[index])
@contextlib.contextmanager
-def override_getfile() -> Generator[None, Any, None]:
+def _override_getfile() -> Generator[None, Any, None]:
try:
- inspect.getfile = getfile # type: ignore[assignment]
+ inspect.getfile = _getfile # type: ignore[assignment]
yield
finally:
inspect.getfile = PY_GETFILE # type: ignore[assignment]
-def getfile(obj: Any) -> str:
+def _getfile(obj: Any) -> str:
if not inspect.isclass(obj):
return PY_GETFILE(obj)
mod = getattr(obj, "__module__", None)
@@ -92,16 +119,16 @@ def getfile(obj: Any) -> str:
if inspect.isfunction(member):
if obj.__qualname__ + "." + member.__name__ == member.__qualname__:
return inspect.getfile(member)
- raise TypeError(f"Source for {obj:!r} not found")
+ raise TypeError(f"Source for {obj} not found")
-def getsourcelines(obj: Any) -> tuple[list[str], int]:
+def _getsourcelines(obj: Any) -> tuple[list[str], int]:
obj = inspect.unwrap(obj)
- lines, l_num = findsource(obj)
+ lines, l_num = _findsource(obj)
return inspect.getblock(lines[l_num:]), l_num + 1
-def findsource(obj: Any) -> tuple[list[str], int]: # noqa: PLR0912
+def _findsource(obj: Any) -> tuple[list[str], int]: # noqa: PLR0912
if not inspect.isclass(obj):
return PY_FINDSOURCE(obj)
@@ -154,19 +181,10 @@ def findsource(obj: Any) -> tuple[list[str], int]: # noqa: PLR0912
raise OSError("could not find class definition")
-@dataclasses.dataclass
-class Source:
- source_name: str
- source_start_line: int
- source_start_column: int
- source: str
- source_full: str
-
-
-def inspect_source(program: Callable | type) -> Source:
- with override_getfile():
+def _inspect_source(program: Callable | type) -> dict[str, int | str]:
+ with _override_getfile():
source_name: str = inspect.getsourcefile(program) # type: ignore
- lines, source_start_line = getsourcelines(program) # type: ignore
+ lines, source_start_line = _getsourcelines(program) # type: ignore
if lines:
source_start_column = len(lines[0]) - len(lines[0].lstrip())
else:
@@ -190,20 +208,20 @@ def inspect_source(program: Callable | type) -> Source:
# as a fallback method.
src, _ = inspect.findsource(program) # type: ignore
source_full = "".join(src)
- return Source(
- source_name=source_name,
- source_start_line=source_start_line,
- source_start_column=source_start_column,
- source=source,
- source_full=source_full,
- )
+ return dict(
+ source_name=source_name,
+ source_start_line=source_start_line,
+ source_start_column=source_start_column,
+ source=source,
+ source_full=source_full,
+ )
-def inspect_annotations_function(program: Callable | type) -> dict[str, dict[str, Any]]:
+def _inspect_annotations_function(program: Callable | type) -> dict[str, dict[str, Any]]:
return {program.__name__: program.__annotations__}
-def inspect_annotations_class(program: Callable | type) -> dict[str, dict[str, Any]]:
+def _inspect_annotations_class(program: Callable | type) -> dict[str, dict[str, Any]]:
annotations = {}
for name, func in program.__dict__.items():
if inspect.isfunction(func):
@@ -211,13 +229,13 @@ def inspect_annotations_class(program: Callable | type) -> dict[str, dict[str, A
return annotations
-def inspect_capture_function(func: Callable) -> dict[str, Any]:
+def _inspect_capture_function(func: Callable) -> dict[str, Any]:
def _getclosurevars(func: Callable) -> dict[str, Any]:
# Mofiied from `inspect.getclosurevars`
if inspect.ismethod(func):
func = func.__func__
if not inspect.isfunction(func):
- raise TypeError(f"{func!r} is not a Python function")
+ raise TypeError(f"Not a Python function: {func}")
code = func.__code__
# Nonlocal references are named in co_freevars and resolved
# by looking them up in __closure__ by positional index
@@ -238,10 +256,10 @@ def _getclosurevars(func: Callable) -> dict[str, Any]:
}
-def inspect_capture_class(cls: type) -> dict[str, Any]:
+def _inspect_capture_class(cls: type) -> dict[str, Any]:
result: dict[str, Any] = {}
for _, v in cls.__dict__.items():
if inspect.isfunction(v):
- func_vars = inspect_capture_function(v)
+ func_vars = _inspect_capture_function(v)
result.update(**func_vars)
return result
diff --git a/python/mlc/parser/parser.py b/python/mlc/parser/parser.py
new file mode 100644
index 00000000..944b716d
--- /dev/null
+++ b/python/mlc/parser/parser.py
@@ -0,0 +1,220 @@
+from __future__ import annotations
+
+import ast
+import builtins
+import copy
+from collections.abc import Generator
+from contextlib import contextmanager
+from typing import Any
+
+from .diagnostic import DiagnosticError, raise_diagnostic_error
+from .env import Env, Span
+
+Value = Any
+
+
+class Frame:  # One lexical scope: a name -> value table that can be pushed onto a Parser
+ name2value: dict[str, Value]  # variables defined in this scope
+
+ def __init__(self) -> None:
+ self.name2value = {}
+
+ def add(self, name: str, value: Value, override: bool = False) -> None:  # Define `name` in this scope
+ if name in self.name2value and not override:  # redefinition is an error unless explicitly allowed
+ raise ValueError(f"Variable already defined in current scope: {name}")
+ self.name2value[name] = value
+
+ @contextmanager
+ def scope(self, parser: Parser) -> Generator[Frame, None, None]:  # Push this frame for the `with` body
+ parser.frames.append(self)
+ try:
+ yield self
+ finally:
+ parser.frames.pop()  # popped even if the body raised
+
+
+class Parser:  # Scope stack, error reporting and expression evaluation on top of an `Env`.
+    env: Env  # supplies `source_full` / `source_name` for diagnostics
+    frames: list[Frame]  # scope stack; `frames[-1]` is the innermost scope
+
+    def __init__(  # Start with a single root frame, optionally pre-populated.
+        self,
+        env: Env,
+        extra_vars: dict[str, Value] | None = None,  # defined after builtins and may shadow them
+        include_builtins: bool = True,  # define every entry of `builtins.__dict__`
+    ) -> None:
+        self.env = env
+        self.frames = [Frame()]
+        if include_builtins:
+            for name, value in builtins.__dict__.items():
+                self.var_def(name, value)
+        if extra_vars is not None:
+            for name, value in extra_vars.items():
+                self.var_def(name, value, override=True)  # a clash with a builtin is not an error
+
+    def var_def(  # Define `name` in `frame` (default: the innermost frame).
+        self,
+        name: str,
+        value: Value,
+        frame: Frame | None = None,
+        override: bool = False,  # without it, redefinition raises ValueError (see `Frame.add`)
+    ) -> None:
+        if frame is None:
+            frame = self.frames[-1]
+        frame.add(name, value, override=override)
+
+    def report_error(self, node: ast.AST | Span, err: Exception | str) -> None:  # Always raises.
+        if isinstance(err, DiagnosticError):  # already rendered once; do not render it again
+            raise err
+        span = Span.from_ast(node, self.env) if isinstance(node, ast.AST) else node
+        raise_diagnostic_error(
+            source=self.env.source_full,
+            source_name=self.env.source_name,
+            span=span,
+            err=err,
+        )
+
+    def eval_expr(self, node: ast.expr) -> Value:  # Evaluate `node` against all visible variables.
+        var_tab: dict[str, Value] = {}
+        for frame in self.frames:  # outermost first, so inner scopes overwrite outer ones
+            var_tab.update(frame.name2value)
+        return ExprEvaluator(parser=self, var_tab=var_tab).run(node)
+
+    def eval_assign(self, target: ast.expr, source: Value) -> dict[str, Value]:  # Bind via `target`.
+        var_name = "__mlc_rhs_var__"  # scratch name that carries `source` into the snippet
+        dict_locals: dict[str, Value] = {var_name: source}
+        mod = ast.fix_missing_locations(  # synthesize the statement `<target> = __mlc_rhs_var__`
+            ast.Module(
+                body=[
+                    ast.Assign(
+                        targets=[target],
+                        value=ast.Name(id=var_name, ctx=ast.Load()),
+                    )
+                ],
+                type_ignores=[],
+            )
+        )
+        exe = compile(mod, filename="", mode="exec")
+        exec(exe, {}, dict_locals)  # pylint: disable=exec-used
+        del dict_locals[var_name]  # leave only the names the assignment bound
+        return dict_locals
+
+
+class ExprEvaluator(ast.NodeTransformer):  # Evaluates an expression bottom-up, one sub-expression at a time
+ parser: Parser  # used only to report errors
+ var_tab: dict[str, Value]  # namespace the snippets run in; intermediate results are stored here too
+ num_tmp_vars: int  # counter that makes temporary names unique
+
+ def __init__(self, parser: Parser, var_tab: dict[str, Value]) -> None:
+ self.parser = parser
+ self.var_tab = var_tab
+ self.num_tmp_vars = 0
+
+ def run(self, node: ast.AST) -> Value:  # Evaluate `node` and return its value
+ var_name = "_mlc_evaluator_tmp_var_result"
+ self._eval_subexpr(
+ expr=self.visit(node),  # sub-expressions handled by visit_* are already replaced by temp names
+ target=ast.Name(id=var_name, ctx=ast.Store()),
+ )
+ return self.var_tab[var_name]
+
+ def _make_intermediate_var(self) -> ast.Name:  # Fresh `_mlc_evaluator_tmp_var_<n>` name node
+ var_name = f"_mlc_evaluator_tmp_var_{self.num_tmp_vars}"
+ self.num_tmp_vars += 1
+ return ast.Name(id=var_name, ctx=ast.Store())
+
+ def visit_Name(self, node: ast.Name) -> ast.Name:  # Names are kept; unknown ones are reported
+ if node.id not in self.var_tab:
+ self.parser.report_error(node, f"Undefined variable: {node.id}")
+ return node
+
+ # Each visit_* below rewrites the node's children in place (mutating the caller's AST),
+ # evaluates the node into a fresh temporary, and returns that temporary's Name node.
+ def visit_Call(self, node: ast.Call) -> ast.Name:
+ node.func = self.visit(node.func)
+ node.args = [self.visit(arg) for arg in node.args]
+ node.keywords = [self.visit(kw) for kw in node.keywords]
+ return self._eval_subexpr(
+ expr=node,
+ target=self._make_intermediate_var(),
+ )
+
+ def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.Name:
+ node.operand = self.visit(node.operand)
+ return self._eval_subexpr(
+ expr=node,
+ target=self._make_intermediate_var(),
+ )
+
+ def visit_BinOp(self, node: ast.BinOp) -> ast.Name:
+ node.left = self.visit(node.left)
+ node.right = self.visit(node.right)
+ return self._eval_subexpr(
+ expr=node,
+ target=self._make_intermediate_var(),
+ )
+
+ def visit_Compare(self, node: ast.Compare) -> ast.Name:
+ node.left = self.visit(node.left)
+ node.comparators = [self.visit(comp) for comp in node.comparators]
+ return self._eval_subexpr(
+ expr=node,
+ target=self._make_intermediate_var(),
+ )
+
+ def visit_BoolOp(self, node: ast.BoolOp) -> ast.Name:  # every operand is visited before the op runs
+ node.values = [self.visit(value) for value in node.values]
+ return self._eval_subexpr(
+ expr=node,
+ target=self._make_intermediate_var(),
+ )
+
+ def visit_IfExp(self, node: ast.IfExp) -> ast.Name:  # test, body and orelse are all visited
+ node.test = self.visit(node.test)
+ node.body = self.visit(node.body)
+ node.orelse = self.visit(node.orelse)
+ return self._eval_subexpr(
+ expr=node,
+ target=self._make_intermediate_var(),
+ )
+
+ def visit_Attribute(self, node: ast.Attribute) -> ast.Name:
+ node.value = self.visit(node.value)
+ return self._eval_subexpr(
+ expr=node,
+ target=self._make_intermediate_var(),
+ )
+
+ def visit_Subscript(self, node: ast.Subscript) -> ast.Name:
+ node.value = self.visit(node.value)
+ node.slice = self.visit(node.slice)
+ return self._eval_subexpr(
+ expr=node,
+ target=self._make_intermediate_var(),
+ )
+
+ def visit_Slice(self, node: ast.Slice) -> ast.Name:  # absent bounds stay None
+ if node.lower is not None:
+ node.lower = self.visit(node.lower)
+ if node.upper is not None:
+ node.upper = self.visit(node.upper)
+ if node.step is not None:
+ node.step = self.visit(node.step)
+ return self._eval_subexpr(
+ expr=node,
+ target=self._make_intermediate_var(),
+ )
+
+ def _eval_subexpr(self, expr: ast.expr, target: ast.Name) -> ast.Name:  # Run `target = expr` in `var_tab`
+ target.ctx = ast.Store()
+ mod = ast.fix_missing_locations(
+ ast.Module(
+ body=[ast.Assign(targets=[target], value=copy.copy(expr))], # target = node
+ type_ignores=[],
+ )
+ )
+ exe = compile(mod, filename="", mode="exec")
+ try:
+ exec(exe, {}, self.var_tab)  # NOTE(review): var_tab is locals only; lambdas/comprehensions inside `expr` presumably cannot see it
+ except Exception as err:
+ self.parser.report_error(node=expr, err=err)  # rendered at the failing sub-expression
+ target.ctx = ast.Load()  # the returned node is substituted into the parent as a read
+ return target
diff --git a/python/mlc/printer/ir_printer.py b/python/mlc/printer/ir_printer.py
index da9603ab..37cbe6db 100644
--- a/python/mlc/printer/ir_printer.py
+++ b/python/mlc/printer/ir_printer.py
@@ -5,7 +5,7 @@
import mlc.dataclasses as mlcd
from mlc.core import Func, Object, ObjectPath
-from .ast import Id, Node, PrinterConfig, Stmt
+from .ast import Expr, Node, PrinterConfig, Stmt
from .cprint import cprint
@@ -44,16 +44,26 @@ def __init__(self, cfg: Optional[PrinterConfig] = None) -> None:
def var_is_defined(self, obj: Any) -> bool:
return bool(IRPrinter._C(b"var_is_defined", self, obj))
- def var_def(self, obj: Any, frame: Any, name: str) -> Id:
- return IRPrinter._C(b"var_def", self, obj, frame, name)
-
- def var_def_no_name(self, obj: Any, creator: Func) -> None:
- IRPrinter._C(b"var_def_no_name", self, obj, creator)
+ def var_def(  # Forward (name, obj, frame) to the C-side `var_def`
+ self,
+ name: str,
+ obj: Any,
+ frame: Optional[Any] = None,
+ ) -> None:  # NOTE(review): annotated None, yet the call result below is returned
+ return IRPrinter._C(b"var_def", self, name, obj, frame)
+
+ def var_def_no_name(  # Forward (creator, obj, frame) to the C-side `var_def_no_name`
+ self,
+ creator: Func,
+ obj: Any,
+ frame: Optional[Any] = None,
+ ) -> None:
+ IRPrinter._C(b"var_def_no_name", self, creator, obj, frame)
def var_remove(self, obj: Any) -> None:
IRPrinter._C(b"var_remove", self, obj)
- def var_get(self, obj: Any) -> Optional[Id]:
+ def var_get(self, obj: Any) -> Optional[Expr]:
return IRPrinter._C(b"var_get", self, obj)
def frame_push(self, frame: Any) -> None:
diff --git a/python/mlc/testing/__init__.py b/python/mlc/testing/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/mlc/testing/toy_ir/__init__.py b/python/mlc/testing/toy_ir/__init__.py
new file mode 100644
index 00000000..61206a55
--- /dev/null
+++ b/python/mlc/testing/toy_ir/__init__.py
@@ -0,0 +1,3 @@
+from .ir import Add, Assign, Expr, Func, Node, Stmt, Var
+from .ir_builder import FunctionFrame, IRBuilder
+from .parser import Parser, parse_func
diff --git a/python/mlc/testing/toy_ir/ir.py b/python/mlc/testing/toy_ir/ir.py
new file mode 100644
index 00000000..5f2e7ad1
--- /dev/null
+++ b/python/mlc/testing/toy_ir/ir.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+import mlc.dataclasses as mlcd
+import mlc.printer as mlcp
+import mlc.printer.ast as mlt
+
+
+@mlcd.py_class
+class Node(mlcd.PyClass): ...  # root of the toy IR hierarchy (Expr, Stmt and Func derive from it)
+
+
+@mlcd.py_class
+class Expr(Node): ...  # base class of expression nodes (Var, Add)
+
+
+@mlcd.py_class
+class Stmt(Node): ...  # base class of statement nodes (Assign)
+
+
+@mlcd.py_class(structure="var")
+class Var(Expr):  # A named variable of the toy IR
+ name: str = mlcd.field(structure=None)
+
+ def __add__(self, other: Var) -> Add:  # `a + b` on Vars builds an Add node
+ return Add(lhs=self, rhs=other)
+
+ def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
+ if not printer.var_is_defined(obj=self):  # define on first use, under the variable's own name
+ printer.var_def(self.name, obj=self)
+ ret = printer.var_get(obj=self)  # printed form registered for this variable
+ assert ret is not None
+ return ret
+
+
+@mlcd.py_class(structure="nobind")
+class Add(Expr):  # Binary addition of two expressions
+ lhs: Expr
+ rhs: Expr
+
+ def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
+ lhs: mlt.Expr = printer(self.lhs, path=path["a"])  # NOTE(review): path key "a" differs from field name `lhs`
+ rhs: mlt.Expr = printer(self.rhs, path=path["b"])  # NOTE(review): path key "b" differs from field name `rhs`
+ return lhs + rhs
+
+
+@mlcd.py_class(structure="bind")
+class Assign(Stmt):  # Statement `lhs = rhs`
+ rhs: Expr
+ lhs: Var = mlcd.field(structure="bind")
+
+ def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
+ rhs: mlt.Expr = printer(self.rhs, path=path["b"])  # printed before `lhs` is defined; NOTE(review): key "b" vs field `rhs`
+ printer.var_def(self.lhs.name, obj=self.lhs)
+ lhs: mlt.Expr = printer(self.lhs, path=path["a"])  # NOTE(review): key "a" vs field `lhs`
+ return mlt.Assign(lhs=lhs, rhs=rhs)
+
+
+@mlcd.py_class(structure="bind")
+class Func(Node):  # A function: name, argument variables, body statements and the returned variable
+ name: str = mlcd.field(structure=None)
+ args: list[Var] = mlcd.field(structure="bind")
+ stmts: list[Stmt]
+ ret: Var
+
+ def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
+ with printer.with_frame(mlcp.DefaultFrame()):  # names below are defined in a fresh printer frame
+ for arg in self.args:  # define all arguments before anything is printed
+ printer.var_def(arg.name, obj=arg)
+ args: list[mlt.Expr] = [
+ printer(arg, path=path["args"][i]) for i, arg in enumerate(self.args)
+ ]
+ stmts: list[mlt.Expr] = [
+ printer(stmt, path=path["stmts"][i]) for i, stmt in enumerate(self.stmts)
+ ]
+ ret_stmt = mlt.Return(printer(self.ret, path=path["ret"]))
+ return mlt.Function(
+ name=mlt.Id(self.name),
+ args=[mlt.Assign(lhs=arg, rhs=None) for arg in args],  # each parameter is an Assign with no rhs
+ decorators=[],
+ return_type=None,
+ body=[*stmts, ret_stmt],  # the return statement is appended after the body
+ )
diff --git a/python/mlc/testing/toy_ir/ir_builder.py b/python/mlc/testing/toy_ir/ir_builder.py
new file mode 100644
index 00000000..797f2a76
--- /dev/null
+++ b/python/mlc/testing/toy_ir/ir_builder.py
@@ -0,0 +1,69 @@
+from __future__ import annotations
+
+from types import TracebackType
+from typing import Any, ClassVar
+
+from .ir import Func, Stmt, Var
+
+
+class IRBuilder:  # Context manager collecting the result of a build; active builders form a class-level stack
+ _ctx: ClassVar[list[IRBuilder]] = []  # stack of entered builders, shared by all instances
+ frames: list[Any]  # frames currently open inside this builder
+ result: Any  # set by FunctionFrame.__exit__; None until then
+
+ def __init__(self) -> None:
+ self.frames = []
+ self.result = None
+
+ def __enter__(self) -> IRBuilder:
+ IRBuilder._ctx.append(self)
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException],
+ exc_value: BaseException,
+ traceback: TracebackType,
+ ) -> None:
+ IRBuilder._ctx.pop()  # returns None, so exceptions from the body propagate
+
+ @staticmethod
+ def get() -> IRBuilder:  # Innermost active builder; IndexError when none is entered
+ return IRBuilder._ctx[-1]
+
+
+class FunctionFrame:  # Accumulates the pieces of one function and turns them into a Func on clean exit
+ name: str
+ args: list[Var]
+ stmts: list[Stmt]
+ ret: Var | None  # stays None until something assigns it
+
+ def __init__(self, name: str) -> None:
+ self.name = name
+ self.args = []
+ self.stmts = []
+ self.ret = None
+
+ def add_arg(self, arg: Var) -> Var:  # Append and return `arg`, so the call can be nested in an expression
+ self.args.append(arg)
+ return arg
+
+ def __enter__(self) -> FunctionFrame:
+ IRBuilder.get().frames.append(self)  # requires an active IRBuilder
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException],  # compared against None below, i.e. None on a clean exit
+ exc_value: BaseException,
+ traceback: TracebackType,
+ ) -> None:
+ frame = IRBuilder.get().frames.pop()
+ assert frame is self  # frames must be closed in the reverse order they were opened
+ if exc_type is None:  # publish a result only when the body finished without an exception
+ IRBuilder.get().result = Func(
+ name=self.name,
+ args=frame.args,
+ stmts=frame.stmts,
+ ret=frame.ret,  # still None if no return value was recorded
+ )
diff --git a/python/mlc/testing/toy_ir/parser.py b/python/mlc/testing/toy_ir/parser.py
new file mode 100644
index 00000000..5e9012eb
--- /dev/null
+++ b/python/mlc/testing/toy_ir/parser.py
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+import ast
+from typing import Callable
+
+import mlc.parser as mlcs
+
+from .ir import Assign, Node, Var
+from .ir_builder import FunctionFrame, IRBuilder
+
+
+class Parser(ast.NodeVisitor):  # Walks a Python AST and emits toy-IR nodes through the active IRBuilder
+ base: mlcs.Parser  # scope stack, expression evaluation and error reporting
+
+ def __init__(self, env: mlcs.Env) -> None:
+ self.base = mlcs.Parser(env, include_builtins=True, extra_vars=None)
+
+ def visit_Assign(self, node: ast.Assign) -> None:  # `name = <expr>` -> Assign(lhs=Var(name), rhs=value)
+ if len(node.targets) != 1:  # rejects `a = b = ...`
+ self.base.report_error(node, "Continuous assignment is not supported")
+ (target,) = node.targets
+ if not isinstance(target, ast.Name):  # rejects tuple/attribute/subscript targets
+ self.base.report_error(target, "Invalid assignment target")
+ assert isinstance(target, ast.Name)  # report_error raised above otherwise; narrows the type
+ value = self.base.eval_assign(
+ target=target,
+ source=self.base.eval_expr(node.value),
+ )[target.id]
+ var = Var(name=target.id)
+ self.base.var_def(name=target.id, value=var)  # no override: redefining a name in this frame raises ValueError
+ IRBuilder.get().frames[-1].stmts.append(
+ Assign(
+ lhs=var,
+ rhs=value,
+ )
+ )
+
+ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:  # Open a scope and a FunctionFrame, then visit the body
+ with mlcs.Frame().scope(self.base):
+ with FunctionFrame(node.name) as frame:
+ for node_arg in node.args.args:  # only `args.args` is read here
+ self.base.var_def(
+ name=node_arg.arg,
+ value=frame.add_arg(Var(name=node_arg.arg)),
+ )
+ for node_stmt in node.body:
+ self.visit(node_stmt)
+
+ def visit_Return(self, node: ast.Return) -> None:  # Record the returned variable on the current frame
+ if not isinstance(node.value, ast.Name):
+ self.base.report_error(node, "Return statement must return a single variable")
+ assert isinstance(node.value, ast.Name)
+ frame: FunctionFrame = IRBuilder.get().frames[-1]
+ frame.ret = self.base.eval_expr(node.value)
+
+
+def parse_func(source: Callable) -> Node:  # Parse a Python function object into toy IR
+ env = mlcs.Env.from_function(source)  # source text, location and captured variables of `source`
+ parser = Parser(env)
+ node = ast.parse(
+ env.source,
+ filename=env.source_name,
+ )
+ with IRBuilder() as ib:
+ parser.visit(node)
+ return ib.result  # set by FunctionFrame.__exit__; None if no function definition was visited
diff --git a/tests/cpp/test_base_optional.cc b/tests/cpp/test_base_optional.cc
index df855185..c166b762 100644
--- a/tests/cpp/test_base_optional.cc
+++ b/tests/cpp/test_base_optional.cc
@@ -274,7 +274,7 @@ TEST(OptionalAnyConversion, ObjectRefType) {
// Construct from AnyView Tests
TEST(OptionalConstructFromAnyView, IntType) {
AnyView view(42);
- Optional opt_int(view);
+ Optional opt_int = view.operator Optional(); // TODO: recover implicit casting
EXPECT_TRUE(opt_int.defined());
EXPECT_EQ(*opt_int, 42);
}
@@ -282,7 +282,7 @@ TEST(OptionalConstructFromAnyView, IntType) {
TEST(OptionalConstructFromAnyView, ObjectRefType) {
TestObjRef obj(Ref::New(10));
AnyView obj_view(obj);
- Optional opt_obj(obj_view);
+ Optional opt_obj = obj_view.operator Optional(); // TODO: recover implicit casting
EXPECT_TRUE(opt_obj.defined());
EXPECT_EQ(opt_obj->value, 10);
}
@@ -290,7 +290,7 @@ TEST(OptionalConstructFromAnyView, ObjectRefType) {
// Construct from Any Tests
TEST(OptionalConstructFromAny, IntType) {
Any any(42);
- Optional opt_int(any);
+ Optional opt_int = any.operator Optional(); // TODO: recover implicit casting
EXPECT_TRUE(opt_int.defined());
EXPECT_EQ(*opt_int, 42);
}
@@ -298,7 +298,7 @@ TEST(OptionalConstructFromAny, IntType) {
TEST(OptionalConstructFromAny, ObjectRefType) {
TestObjRef obj(Ref::New(10));
Any obj_any(obj);
- Optional opt_obj(obj_any);
+ Optional opt_obj = obj_any.operator Optional(); // TODO: recover implicit casting
EXPECT_TRUE(opt_obj.defined());
EXPECT_EQ(opt_obj->value, 10);
}
diff --git a/tests/cpp/test_base_ref.cc b/tests/cpp/test_base_ref.cc
index 36e425ee..4205c086 100644
--- a/tests/cpp/test_base_ref.cc
+++ b/tests/cpp/test_base_ref.cc
@@ -303,7 +303,7 @@ TEST(RefPOD, ConversionToAny) {
Any any = ref;
ref.Reset();
EXPECT_EQ(any.operator int64_t(), 42);
- ref = Ref(any);
+ ref = any.operator Ref();
EXPECT_EQ(*ref, 42);
EXPECT_EQ(GetRefCount(ref), 1);
}
@@ -319,7 +319,7 @@ TEST(RefPOD, ConversionToAnyView) {
AnyView any_view = ref;
ref.Reset();
EXPECT_EQ(any_view.operator int64_t(), 42);
- ref = Ref(any_view);
+ ref = any_view.operator Ref();
EXPECT_EQ(*ref, 42);
EXPECT_EQ(GetRefCount(ref), 1);
}
diff --git a/tests/python/test_parser_toy_ir_parser.py b/tests/python/test_parser_toy_ir_parser.py
new file mode 100644
index 00000000..3c502a15
--- /dev/null
+++ b/tests/python/test_parser_toy_ir_parser.py
@@ -0,0 +1,28 @@
+from __future__ import annotations
+
+from mlc.testing import toy_ir
+from mlc.testing.toy_ir import Add, Assign, Func, Var
+
+
+def test_parse_func() -> None:
+ def source_code(a, b, c): # noqa: ANN001, ANN202
+ d = a + b
+ e = d + c
+ return e
+
+ def _expected() -> Func:
+ a = Var(name="_a")
+ b = Var(name="_b")
+ c = Var(name="_c")
+ d = Var(name="_d")
+ e = Var(name="_e")
+ stmts = [
+ Assign(lhs=d, rhs=Add(a, b)),
+ Assign(lhs=e, rhs=Add(d, c)),
+ ]
+ f = Func(name="_f", args=[a, b, c], stmts=stmts, ret=e)
+ return f
+
+ result = toy_ir.parse_func(source_code)
+ expected = _expected()
+ result.eq_s(expected, assert_mode=True)
diff --git a/tests/python/test_printer_ir_printer.py b/tests/python/test_printer_ir_printer.py
index 61f04bb8..d7f8057b 100644
--- a/tests/python/test_printer_ir_printer.py
+++ b/tests/python/test_printer_ir_printer.py
@@ -1,76 +1,5 @@
-import mlc.dataclasses as mlcd
import mlc.printer as mlcp
-from mlc.printer import ast as mlt
-
-
-@mlcd.py_class
-class Expr(mlcd.PyClass): ...
-
-
-@mlcd.py_class
-class Stmt(mlcd.PyClass): ...
-
-
-@mlcd.py_class
-class Var(Expr):
- name: str
-
- def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
- if not printer.var_is_defined(obj=self):
- printer.var_def(obj=self, frame=printer.frames[-1], name=self.name)
- ret = printer.var_get(obj=self)
- assert isinstance(ret, mlt.Id)
- return ret
-
-
-@mlcd.py_class
-class Add(Expr):
- lhs: Expr
- rhs: Expr
-
- def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
- lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"])
- rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"])
- return lhs + rhs
-
-
-@mlcd.py_class
-class Assign(Stmt):
- lhs: Var
- rhs: Expr
-
- def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
- rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"])
- printer.var_def(obj=self.lhs, frame=printer.frames[-1], name=self.lhs.name)
- lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"])
- return mlt.Assign(lhs=lhs, rhs=rhs)
-
-
-@mlcd.py_class
-class Func(mlcd.PyClass):
- name: str
- args: list[Var]
- stmts: list[Stmt]
- ret: Var
-
- def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
- with printer.with_frame(mlcp.DefaultFrame()):
- for arg in self.args:
- printer.var_def(obj=arg, frame=printer.frames[-1], name=arg.name)
- args: list[mlt.Expr] = [
- printer(obj=arg, path=path["args"][i]) for i, arg in enumerate(self.args)
- ]
- stmts: list[mlt.Expr] = [
- printer(obj=stmt, path=path["stmts"][i]) for i, stmt in enumerate(self.stmts)
- ]
- ret_stmt = mlt.Return(printer(obj=self.ret, path=path["ret"]))
- return mlt.Function(
- name=mlt.Id(self.name),
- args=[mlt.Assign(lhs=arg, rhs=None) for arg in args],
- decorators=[],
- return_type=None,
- body=[*stmts, ret_stmt],
- )
+from mlc.testing.toy_ir import Add, Assign, Func, Var
def test_var_print() -> None: