From c1a179c817578bc3f71e2e590930c713e7c6d2c9 Mon Sep 17 00:00:00 2001 From: Mashed Potato <38517644+potatomashed@users.noreply.github.com> Date: Sat, 28 Dec 2024 10:56:14 -0800 Subject: [PATCH] AST Parser --- CMakeLists.txt | 2 +- README.md | 93 +- include/mlc/core/typing.h | 38 +- include/mlc/core/utils.h | 7 + include/mlc/printer/ir_printer.h | 13 +- pyproject.toml | 5 +- python/mlc/__init__.py | 2 +- python/mlc/_cython/core.pyx | 8 +- python/mlc/ast/__init__.py | 4 - python/mlc/ast/mlc_ast.py | 855 ------------------ python/mlc/ast/translate.py | 109 --- python/mlc/core/typing.py | 34 +- python/mlc/dataclasses/utils.py | 7 +- python/mlc/parser/__init__.py | 3 + python/mlc/parser/diagnostic.py | 116 +++ .../mlc/{ast/inspection.py => parser/env.py} | 182 ++-- python/mlc/parser/parser.py | 220 +++++ python/mlc/printer/ir_printer.py | 24 +- python/mlc/testing/__init__.py | 0 python/mlc/testing/toy_ir/__init__.py | 3 + python/mlc/testing/toy_ir/ir.py | 82 ++ python/mlc/testing/toy_ir/ir_builder.py | 69 ++ python/mlc/testing/toy_ir/parser.py | 66 ++ tests/cpp/test_base_optional.cc | 8 +- tests/cpp/test_base_ref.cc | 4 +- tests/python/test_parser_toy_ir_parser.py | 28 + tests/python/test_printer_ir_printer.py | 73 +- 27 files changed, 793 insertions(+), 1262 deletions(-) delete mode 100644 python/mlc/ast/__init__.py delete mode 100644 python/mlc/ast/mlc_ast.py delete mode 100644 python/mlc/ast/translate.py create mode 100644 python/mlc/parser/__init__.py create mode 100644 python/mlc/parser/diagnostic.py rename python/mlc/{ast/inspection.py => parser/env.py} (60%) create mode 100644 python/mlc/parser/parser.py create mode 100644 python/mlc/testing/__init__.py create mode 100644 python/mlc/testing/toy_ir/__init__.py create mode 100644 python/mlc/testing/toy_ir/ir.py create mode 100644 python/mlc/testing/toy_ir/ir_builder.py create mode 100644 python/mlc/testing/toy_ir/parser.py create mode 100644 tests/python/test_parser_toy_ir_parser.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 8f40d945..961ccd48 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15) project( mlc - VERSION 0.0.13 + VERSION 0.0.14 DESCRIPTION "MLC-Python" LANGUAGES C CXX ) diff --git a/README.md b/README.md index 2b93b258..a9fdd74f 100644 --- a/README.md +++ b/README.md @@ -121,87 +121,21 @@ ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent bindi ### :snake: Text Formats in Python -**IR Printer.** By defining an `__ir_print__` method, which converts an IR node to MLC's Python-style AST, MLC's `IRPrinter` handles variable scoping, renaming and syntax highlighting automatically for a text format based on Python syntax. +**Printer.** MLC converts an IR node to Python AST by looking up the `__ir_print__` method. -
Defining Python-based text format on a toy IR using `__ir_print__`. +**[[Example](https://github.com/mlc-ai/mlc-python/blob/main/python/mlc/testing/toy_ir/ir.py)]**. Copy the toy IR definition to REPL and then create a `Func` node below: ```python -import mlc.dataclasses as mlcd -import mlc.printer as mlcp -from mlc.printer import ast as mlt - -@mlcd.py_class -class Expr(mlcd.PyClass): ... - -@mlcd.py_class -class Stmt(mlcd.PyClass): ... - -@mlcd.py_class -class Var(Expr): - name: str - def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: - if not printer.var_is_defined(obj=self): - printer.var_def(obj=self, frame=printer.frames[-1], name=self.name) - return printer.var_get(obj=self) - -@mlcd.py_class -class Add(Expr): - lhs: Expr - rhs: Expr - def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: - lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"]) - rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"]) - return lhs + rhs - -@mlcd.py_class -class Assign(Stmt): - lhs: Var - rhs: Expr - def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: - rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"]) - printer.var_def(obj=self.lhs, frame=printer.frames[-1], name=self.lhs.name) - lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"]) - return mlt.Assign(lhs=lhs, rhs=rhs) - -@mlcd.py_class -class Func(mlcd.PyClass): - name: str - args: list[Var] - stmts: list[Stmt] - ret: Var - def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: - with printer.with_frame(mlcp.DefaultFrame()): - for arg in self.args: - printer.var_def(obj=arg, frame=printer.frames[-1], name=arg.name) - args: list[mlt.Expr] = [printer(obj=arg, path=path["args"][i]) for i, arg in enumerate(self.args)] - stmts: list[mlt.Expr] = [printer(obj=stmt, path=path["stmts"][i]) for i, stmt in enumerate(self.stmts)] - ret_stmt = mlt.Return(printer(obj=self.ret, path=path["ret"])) - return mlt.Function( - name=mlt.Id(self.name), - args=[mlt.Assign(lhs=arg, rhs=None) for arg in args], - decorators=[], - return_type=None, - body=[*stmts, ret_stmt], - ) - -# An example IR: -a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e") -f = Func( - name="f", - args=[a, b, c], +>>> a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e") +>>> f = Func("f", [a, b, c], stmts=[ Assign(lhs=d, rhs=Add(a, b)), # d = a + b Assign(lhs=e, rhs=Add(d, c)), # e = d + c ], - ret=e, -) + ret=e) ``` -
- -Two printer APIs are provided for Python-based text format: -- `mlc.printer.to_python` that converts an IR fragment to Python text, and -- `mlc.printer.print_python` that further renders the text with proper syntax highlighting. +- Method `mlc.printer.to_python` converts an IR node to Python-based text; ```python >>> print(mlcp.to_python(f)) # Stringify to Python @@ -209,12 +143,25 @@ def f(a, b, c): d = a + b e = d + c return e +``` + +- Method `mlc.printer.print_python` further renders the text with proper syntax highlighting. [[Screenshot](https://raw.githubusercontent.com/gist/potatomashed/5a9b20edbdde1b9a91a360baa6bce9ff/raw/3c68031eaba0620a93add270f8ad7ed2c8724a78/mlc-python-printer.svg)] + +```python >>> mlcp.print_python(f) # Syntax highlighting ``` +**AST Parser.** MLC has a concise set of APIs for implementing parser with Python's AST module, including: +- Inspection API that obtains source code of a Python class or function and the variables they capture; +- Variable management APIs that help with proper scoping; +- AST fragment evaluation APIs; +- Error rendering APIs. + +**[[Example](https://github.com/mlc-ai/mlc-python/blob/main/python/mlc/testing/toy_ir/parser.py)]**. With MLC APIs, a parser can be implemented with 100 lines of code for the Python text format above defined by `__ir_printer__`. + ### :zap: Zero-Copy Interoperability with C++ Plugins -TBD +🚧 Under construction. ## :fuelpump: Development diff --git a/include/mlc/core/typing.h b/include/mlc/core/typing.h index 0021b792..1d8a99da 100644 --- a/include/mlc/core/typing.h +++ b/include/mlc/core/typing.h @@ -122,8 +122,7 @@ struct AtomicType : public Type { }; struct PtrTypeObj : protected MLCTypingPtr { - explicit PtrTypeObj(Type ty) : MLCTypingPtr{} { this->TyMutable() = ty; } - Type Ty() const { return Type(reinterpret_cast &>(this->MLCTypingPtr::ty)); } + explicit PtrTypeObj(Type ty) : MLCTypingPtr{} { this->TyMut() = ty; } ::mlc::Str __str__() const { std::ostringstream os; os << "Ptr[" << this->Ty() << "]"; @@ -138,20 +137,20 @@ struct PtrTypeObj : protected MLCTypingPtr { MLC_DEF_STATIC_TYPE(PtrTypeObj, TypeObj, MLCTypeIndex::kMLCTypingPtr, "mlc.core.typing.PtrType"); private: - Type &TyMutable() { return reinterpret_cast(this->MLCTypingPtr::ty); } + Type &TyMut() { return reinterpret_cast(this->MLCTypingPtr::ty); } + Type Ty() const { return Type(reinterpret_cast &>(this->MLCTypingPtr::ty)); } }; struct PtrType : public Type { MLC_DEF_OBJ_REF(PtrType, PtrTypeObj, Type) .StaticFn("__init__", InitOf) - .MemFn("_ty", &PtrTypeObj::Ty) + ._Field("ty", offsetof(MLCTypingPtr, ty), sizeof(MLCTypingPtr::ty), false, ParseType()) .MemFn("__str__", &PtrTypeObj::__str__) .MemFn("__cxx_str__", &PtrTypeObj::__cxx_str__); }; struct OptionalObj : protected MLCTypingOptional { explicit OptionalObj(Type ty) : MLCTypingOptional{} { this->TyMutable() = ty; } - Type Ty() const { return Type(reinterpret_cast &>(this->MLCTypingOptional::ty)); } ::mlc::Str __str__() const { std::ostringstream os; os << this->Ty() << " | None"; @@ -167,19 +166,19 @@ struct OptionalObj : protected MLCTypingOptional { private: Type &TyMutable() { return reinterpret_cast(this->MLCTypingOptional::ty); } + Type Ty() const { return Type(reinterpret_cast &>(this->MLCTypingOptional::ty)); } }; struct Optional : public Type { MLC_DEF_OBJ_REF(Optional, OptionalObj, Type) .StaticFn("__init__", InitOf) - .MemFn("_ty", &OptionalObj::Ty) + ._Field("ty", offsetof(MLCTypingOptional, ty), sizeof(MLCTypingOptional::ty), false, ParseType()) .MemFn("__str__", &OptionalObj::__str__) .MemFn("__cxx_str__", &OptionalObj::__cxx_str__); }; struct ListObj : protected MLCTypingList { explicit ListObj(Type ty) : MLCTypingList{} { this->TyMutable() = ty; } - Type Ty() const { return Type(reinterpret_cast &>(this->MLCTypingList::ty)); } ::mlc::Str __str__() const { std::ostringstream os; os << "list[" << this->Ty() << "]"; @@ -195,31 +194,30 @@ struct ListObj : protected MLCTypingList { protected: Type &TyMutable() { return reinterpret_cast(this->MLCTypingList::ty); } + Type Ty() const { return Type(reinterpret_cast &>(this->MLCTypingList::ty)); } }; struct List : public Type { MLC_DEF_OBJ_REF(List, ListObj, Type) .StaticFn("__init__", InitOf) - .MemFn("_ty", &ListObj::Ty) + ._Field("ty", offsetof(MLCTypingList, ty), sizeof(MLCTypingList::ty), false, ParseType()) .MemFn("__str__", &ListObj::__str__) .MemFn("__cxx_str__", &ListObj::__cxx_str__); }; struct DictObj : protected MLCTypingDict { explicit DictObj(Type ty_k, Type ty_v) : MLCTypingDict{} { - this->TyMutableK() = ty_k; - this->TyMutableV() = ty_v; + this->TyKMut() = ty_k; + this->TyVMut() = ty_v; } - Type key() const { return Type(reinterpret_cast &>(this->ty_k)); } - Type value() const { return Type(reinterpret_cast &>(this->ty_v)); } ::mlc::Str __str__() const { std::ostringstream os; - os << "dict[" << this->key() << ", " << this->value() << "]"; + os << "dict[" << this->TyK() << ", " << this->TyV() << "]"; return os.str(); } ::mlc::Str __cxx_str__() const { - ::mlc::Str k_str = ::mlc::base::LibState::CxxStr(this->key()); - ::mlc::Str v_str = ::mlc::base::LibState::CxxStr(this->value()); + ::mlc::Str k_str = ::mlc::base::LibState::CxxStr(this->TyK()); + ::mlc::Str v_str = ::mlc::base::LibState::CxxStr(this->TyV()); std::ostringstream os; os << "::mlc::Dict<" << k_str->data() << ", " << v_str->data() << ">"; return os.str(); @@ -227,15 +225,17 @@ struct DictObj : protected MLCTypingDict { MLC_DEF_STATIC_TYPE(DictObj, TypeObj, MLCTypeIndex::kMLCTypingDict, "mlc.core.typing.Dict"); protected: - Type &TyMutableK() { return reinterpret_cast(this->ty_k); } - Type &TyMutableV() { return reinterpret_cast(this->ty_v); } + Type &TyKMut() { return reinterpret_cast(this->ty_k); } + Type &TyVMut() { return reinterpret_cast(this->ty_v); } + Type TyK() const { return Type(reinterpret_cast &>(this->ty_k)); } + Type TyV() const { return Type(reinterpret_cast &>(this->ty_v)); } }; struct Dict : public Type { MLC_DEF_OBJ_REF(Dict, DictObj, Type) .StaticFn("__init__", InitOf) - .MemFn("_key", &DictObj::key) - .MemFn("_value", &DictObj::value) + ._Field("ty_k", offsetof(MLCTypingDict, ty_k), sizeof(MLCTypingDict::ty_k), false, ParseType()) + ._Field("ty_v", offsetof(MLCTypingDict, ty_v), sizeof(MLCTypingDict::ty_v), false, ParseType()) .MemFn("__str__", &DictObj::__str__) .MemFn("__cxx_str__", &DictObj::__cxx_str__); }; diff --git a/include/mlc/core/utils.h b/include/mlc/core/utils.h index 3bf2815e..67f86eab 100644 --- a/include/mlc/core/utils.h +++ b/include/mlc/core/utils.h @@ -128,6 +128,13 @@ struct ReflectionHelper { return *this; } + inline ReflectionHelper &_Field(const char *name, int64_t field_offset, int32_t num_bytes, bool frozen, Any ty) { + this->any_pool.push_back(ty); + int32_t index = static_cast(this->fields.size()); + this->fields.emplace_back(MLCTypeField{name, index, field_offset, num_bytes, frozen, ty.v.v_obj}); + return *this; + } + template inline ReflectionHelper &MemFn(const char *name, Callable &&method) { MLCTypeMethod m = this->PrepareMethod(name, std::forward(method)); m.kind = kMemFn; diff --git a/include/mlc/printer/ir_printer.h b/include/mlc/printer/ir_printer.h index c4cdf52e..83213948 100644 --- a/include/mlc/printer/ir_printer.h +++ b/include/mlc/printer/ir_printer.h @@ -49,7 +49,7 @@ struct IRPrinterObj : public Object { bool VarIsDefined(const ObjectRef &obj) { return obj2info->count(obj) > 0; } - Id VarDef(const ObjectRef &obj, const ObjectRef &frame, Str name_hint) { + Id VarDef(Str name_hint, const ObjectRef &obj, const Optional &frame) { if (auto it = obj2info.find(obj); it != obj2info.end()) { Optional name = (*it).second->name; return Id(name.value()); @@ -66,18 +66,19 @@ struct IRPrinterObj : public Object { name = name_hint.ToStdString() + '_' + std::to_string(i); } defined_names->Set(name, 1); - this->_VarDef(obj, frame, VarInfo(name, Func([name]() { return Id(name); }))); + this->_VarDef(VarInfo(name, Func([name]() { return Id(name); })), obj, frame); return Id(name); } - void VarDefNoName(const ObjectRef &obj, const ObjectRef &frame, const Func &creator) { + void VarDefNoName(const Func &creator, const ObjectRef &obj, const Optional &frame) { if (obj2info.count(obj) > 0) { MLC_THROW(KeyError) << "Variable already defined: " << obj; } - this->_VarDef(obj, frame, VarInfo(mlc::Null, creator)); + this->_VarDef(VarInfo(mlc::Null, creator), obj, frame); } - void _VarDef(const ObjectRef &obj, const ObjectRef &frame, VarInfo var_info) { + void _VarDef(VarInfo var_info, const ObjectRef &obj, const Optional &_frame) { + ObjectRef frame = _frame.defined() ? _frame.value() : this->frames.back().operator ObjectRef(); obj2info->Set(obj, var_info); auto it = frame_vars.find(frame); if (it == frame_vars.end()) { @@ -99,7 +100,7 @@ struct IRPrinterObj : public Object { obj2info.erase(it); } - Optional VarGet(const ObjectRef &obj) { + Optional VarGet(const ObjectRef &obj) { auto it = obj2info.find(obj); if (it == obj2info.end()) { return Null; diff --git a/pyproject.toml b/pyproject.toml index 9dd02b0b..02cabca4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,11 @@ [project] name = "mlc-python" -version = "0.0.13" +version = "0.0.14" dependencies = [ 'numpy >= 1.22', - "ml-dtypes >= 0.1", + 'ml-dtypes >= 0.1', 'Pygments>=2.4.0', + 'colorama', 'setuptools ; platform_system == "Windows"', ] description = "" diff --git a/python/mlc/__init__.py b/python/mlc/__init__.py index 148dd6a7..8c060f2c 100644 --- a/python/mlc/__init__.py +++ b/python/mlc/__init__.py @@ -1,4 +1,4 @@ -from . import _cython, ast, cc, dataclasses, printer +from . import _cython, cc, dataclasses, parser, printer from ._cython import Ptr, Str from .core import DataType, Device, Dict, Error, Func, List, Object, ObjectPath, typing from .dataclasses import PyClass, c_class, py_class diff --git a/python/mlc/_cython/core.pyx b/python/mlc/_cython/core.pyx index 0a1fdd52..93dd60fc 100644 --- a/python/mlc/_cython/core.pyx +++ b/python/mlc/_cython/core.pyx @@ -1360,12 +1360,16 @@ def make_mlc_init(list fields): cdef tuple setters = _setters cdef int32_t num_args = len(args) cdef int32_t i = 0 + cdef object e = None assert num_args == len(setters) while i < num_args: try: setters[i](self, args[i]) - except Exception as e: # no-cython-lint - raise ValueError(f"Failed to set field `{fields[i].name}`: {str(e)}. Got: {args[i]}") + except Exception as _e: # no-cython-lint + e = ValueError(f"Failed to set field `{fields[i].name}`: {str(_e)}. Got: {args[i]}") + e = e.with_traceback(_e.__traceback__) + if e is not None: + raise e i += 1 return _mlc_init diff --git a/python/mlc/ast/__init__.py b/python/mlc/ast/__init__.py deleted file mode 100644 index 402b60b4..00000000 --- a/python/mlc/ast/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from . import translate -from .inspection import InspectResult, inspect_program -from .mlc_ast import * # noqa: F403 -from .translate import translate_mlc_to_py, translate_py_to_mlc diff --git a/python/mlc/ast/mlc_ast.py b/python/mlc/ast/mlc_ast.py deleted file mode 100644 index 72bb7861..00000000 --- a/python/mlc/ast/mlc_ast.py +++ /dev/null @@ -1,855 +0,0 @@ -# Derived from -# https://github.com/python/typeshed/blob/main/stdlib/ast.pyi - -from typing import Any, Optional - -from mlc import dataclasses as mlcd - -_Identifier = str - - -@mlcd.py_class(type_key="mlc.ast.AST", structure="bind") -class AST(mlcd.PyClass): ... - - -@mlcd.py_class(type_key="mlc.ast.mod", structure="bind") -class mod(AST): ... - - -@mlcd.py_class(type_key="mlc.ast.expr_context", structure="bind") -class expr_context(AST): ... - - -@mlcd.py_class(type_key="mlc.ast.operator", structure="bind") -class operator(AST): ... - - -@mlcd.py_class(type_key="mlc.ast.cmpop", structure="bind") -class cmpop(AST): ... - - -@mlcd.py_class(type_key="mlc.ast.unaryop", structure="bind") -class unaryop(AST): ... - - -@mlcd.py_class(type_key="mlc.ast.boolop", structure="bind") -class boolop(AST): ... - - -@mlcd.py_class(type_key="mlc.ast.type_ignore", structure="bind") -class type_ignore(AST): ... - - -@mlcd.py_class(type_key="mlc.ast.stmt", structure="bind") -class stmt(AST): - lineno: Optional[int] - col_offset: Optional[int] - end_lineno: Optional[int] - end_col_offset: Optional[int] - - -@mlcd.py_class(type_key="mlc.ast.expr", structure="bind") -class expr(AST): - lineno: Optional[int] - col_offset: Optional[int] - end_lineno: Optional[int] - end_col_offset: Optional[int] - - -@mlcd.py_class(type_key="mlc.ast.type_param", structure="bind") -class type_param(AST): - lineno: Optional[int] - col_offset: Optional[int] - end_lineno: Optional[int] - end_col_offset: Optional[int] - - -@mlcd.py_class(type_key="mlc.ast.pattern", structure="bind") -class pattern(AST): - lineno: Optional[int] - col_offset: Optional[int] - end_lineno: Optional[int] - end_col_offset: Optional[int] - - -@mlcd.py_class(type_key="mlc.ast.arg", structure="bind") -class arg(AST): - lineno: Optional[int] - col_offset: Optional[int] - end_lineno: Optional[int] - end_col_offset: Optional[int] - arg: _Identifier - annotation: Optional[expr] - type_comment: Optional[str] - - -@mlcd.py_class(type_key="mlc.ast.keyword", structure="bind") -class keyword(AST): - lineno: Optional[int] - col_offset: Optional[int] - end_lineno: Optional[int] - end_col_offset: Optional[int] - arg: Optional[_Identifier] - value: expr - - -@mlcd.py_class(type_key="mlc.ast.alias", structure="bind") -class alias(AST): - lineno: Optional[int] - col_offset: Optional[int] - end_lineno: Optional[int] - end_col_offset: Optional[int] - name: str - asname: Optional[_Identifier] - - -@mlcd.py_class(type_key="mlc.ast.arguments", structure="bind") -class arguments(AST): - posonlyargs: list[arg] - args: list[arg] - vararg: Optional[arg] - kwonlyargs: list[arg] - kw_defaults: list[Optional[expr]] - kwarg: Optional[arg] - defaults: list[expr] - - -@mlcd.py_class(type_key="mlc.ast.comprehension", structure="bind") -class comprehension(AST): - target: expr - iter: expr - ifs: list[expr] - is_async: int - - -@mlcd.py_class(type_key="mlc.ast.withitem", structure="bind") -class withitem(AST): - context_expr: expr - optional_vars: Optional[expr] - - -@mlcd.py_class(type_key="mlc.ast.TypeIgnore", structure="bind") -class TypeIgnore(type_ignore): - lineno: Optional[int] - tag: str - - -@mlcd.py_class(type_key="mlc.ast.Module", structure="bind") -class Module(mod): - body: list[stmt] - type_ignores: list[TypeIgnore] - - -@mlcd.py_class(type_key="mlc.ast.Interactive", structure="bind") -class Interactive(mod): - body: list[stmt] - - -@mlcd.py_class(type_key="mlc.ast.Expression", structure="bind") -class Expression(mod): - body: expr - - -@mlcd.py_class(type_key="mlc.ast.FunctionType", structure="bind") -class FunctionType(mod): - argtypes: list[expr] - returns: expr - - -@mlcd.py_class(type_key="mlc.ast.FunctionDef", structure="bind") -class FunctionDef(stmt): - name: _Identifier - args: arguments - body: list[stmt] - decorator_list: list[expr] - returns: Optional[expr] - type_comment: Optional[str] - type_params: Optional[list[type_param]] - - -@mlcd.py_class(type_key="mlc.ast.AsyncFunctionDef", structure="bind") -class AsyncFunctionDef(stmt): - name: _Identifier - args: arguments - body: list[stmt] - decorator_list: list[expr] - returns: Optional[expr] - type_comment: Optional[str] - type_params: Optional[list[type_param]] - - -@mlcd.py_class(type_key="mlc.ast.ClassDef", structure="bind") -class ClassDef(stmt): - name: _Identifier - bases: list[expr] - keywords: list[keyword] - body: list[stmt] - decorator_list: list[expr] - type_params: Optional[list[type_param]] - - -@mlcd.py_class(type_key="mlc.ast.Return", structure="bind") -class Return(stmt): - value: Optional[expr] - - -@mlcd.py_class(type_key="mlc.ast.Delete", structure="bind") -class Delete(stmt): - targets: list[expr] - - -@mlcd.py_class(type_key="mlc.ast.Assign", structure="bind") -class Assign(stmt): - targets: list[expr] - value: expr - type_comment: Optional[str] - - -@mlcd.py_class(type_key="mlc.ast.Attribute", structure="bind") -class Attribute(expr): - value: expr - attr: _Identifier - ctx: expr_context - - -@mlcd.py_class(type_key="mlc.ast.Subscript", structure="bind") -class Subscript(expr): - value: expr - slice: expr - ctx: expr_context - - -@mlcd.py_class(type_key="mlc.ast.AugAssign", structure="bind") -class AugAssign(stmt): - target: Any # Name | Attribute | Subscript - op: operator - value: expr - - -@mlcd.py_class(type_key="mlc.ast.AnnAssign", structure="bind") -class AnnAssign(stmt): - target: Any # Name | Attribute | Subscript - annotation: expr - value: Optional[expr] - simple: int - - -@mlcd.py_class(type_key="mlc.ast.For", structure="bind") -class For(stmt): - target: expr - iter: expr - body: list[stmt] - orelse: list[stmt] - type_comment: Optional[str] - - -@mlcd.py_class(type_key="mlc.ast.AsyncFor", structure="bind") -class AsyncFor(stmt): - target: expr - iter: expr - body: list[stmt] - orelse: list[stmt] - type_comment: Optional[str] - - -@mlcd.py_class(type_key="mlc.ast.While", structure="bind") -class While(stmt): - test: expr - body: list[stmt] - orelse: list[stmt] - - -@mlcd.py_class(type_key="mlc.ast.If", structure="bind") -class If(stmt): - test: expr - body: list[stmt] - orelse: list[stmt] - - -@mlcd.py_class(type_key="mlc.ast.With", structure="bind") -class With(stmt): - items: list[withitem] - body: list[stmt] - type_comment: Optional[str] - - -@mlcd.py_class(type_key="mlc.ast.AsyncWith", structure="bind") -class AsyncWith(stmt): - items: list[withitem] - body: list[stmt] - type_comment: Optional[str] - - -@mlcd.py_class(type_key="mlc.ast.match_case", structure="bind") -class match_case(AST): - pattern: pattern - guard: Optional[expr] - body: list[stmt] - - -@mlcd.py_class(type_key="mlc.ast.Match", structure="bind") -class Match(stmt): - subject: expr - cases: list[match_case] - - -@mlcd.py_class(type_key="mlc.ast.Raise", structure="bind") -class Raise(stmt): - exc: Optional[expr] - cause: Optional[expr] - - -@mlcd.py_class(type_key="mlc.ast.ExceptHandler", structure="bind") -class ExceptHandler(AST): - lineno: Optional[int] - col_offset: Optional[int] - end_lineno: Optional[int] - end_col_offset: Optional[int] - type: Optional[expr] - name: Optional[_Identifier] - body: list[stmt] - - -@mlcd.py_class(type_key="mlc.ast.Try", structure="bind") -class Try(stmt): - body: list[stmt] - handlers: list[ExceptHandler] - orelse: list[stmt] - finalbody: list[stmt] - - -@mlcd.py_class(type_key="mlc.ast.TryStar", structure="bind") -class TryStar(stmt): - body: list[stmt] - handlers: list[ExceptHandler] - orelse: list[stmt] - finalbody: list[stmt] - - -@mlcd.py_class(type_key="mlc.ast.Assert", structure="bind") -class Assert(stmt): - test: expr - msg: Optional[expr] - - -@mlcd.py_class(type_key="mlc.ast.Import", structure="bind") -class Import(stmt): - names: list[alias] - - -@mlcd.py_class(type_key="mlc.ast.ImportFrom", structure="bind") -class ImportFrom(stmt): - module: Optional[str] - names: list[alias] - level: int - - -@mlcd.py_class(type_key="mlc.ast.Global", structure="bind") -class Global(stmt): - names: list[_Identifier] - - -@mlcd.py_class(type_key="mlc.ast.Nonlocal", structure="bind") -class Nonlocal(stmt): - names: list[_Identifier] - - -@mlcd.py_class(type_key="mlc.ast.Expr", structure="bind") -class Expr(stmt): - value: expr - - -@mlcd.py_class(type_key="mlc.ast.Pass", structure="bind") -class Pass(stmt): ... - - -@mlcd.py_class(type_key="mlc.ast.Break", structure="bind") -class Break(stmt): ... - - -@mlcd.py_class(type_key="mlc.ast.Continue", structure="bind") -class Continue(stmt): ... - - -@mlcd.py_class(type_key="mlc.ast.BoolOp", structure="bind") -class BoolOp(expr): - op: boolop - values: list[expr] - - -@mlcd.py_class(type_key="mlc.ast.Name", structure="bind") -class Name(expr): - id: _Identifier - ctx: expr_context - - -@mlcd.py_class(type_key="mlc.ast.NamedExpr", structure="bind") -class NamedExpr(expr): - target: Name - value: expr - - -@mlcd.py_class(type_key="mlc.ast.BinOp", structure="bind") -class BinOp(expr): - left: expr - op: operator - right: expr - - -@mlcd.py_class(type_key="mlc.ast.UnaryOp", structure="bind") -class UnaryOp(expr): - op: unaryop - operand: expr - - -@mlcd.py_class(type_key="mlc.ast.Lambda", structure="bind") -class Lambda(expr): - args: arguments - body: expr - - -@mlcd.py_class(type_key="mlc.ast.IfExp", structure="bind") -class IfExp(expr): - test: expr - body: expr - orelse: expr - - -@mlcd.py_class(type_key="mlc.ast.Dict", structure="bind") -class Dict(expr): - keys: list[Optional[expr]] - values: list[expr] - - -@mlcd.py_class(type_key="mlc.ast.Set", structure="bind") -class Set(expr): - elts: list[expr] - - -@mlcd.py_class(type_key="mlc.ast.ListComp", structure="bind") -class ListComp(expr): - elt: expr - generators: list[comprehension] - - -@mlcd.py_class(type_key="mlc.ast.SetComp", structure="bind") -class SetComp(expr): - elt: expr - generators: list[comprehension] - - -@mlcd.py_class(type_key="mlc.ast.DictComp", structure="bind") -class DictComp(expr): - key: expr - value: expr - generators: list[comprehension] - - -@mlcd.py_class(type_key="mlc.ast.GeneratorExp", structure="bind") -class GeneratorExp(expr): - elt: expr - generators: list[comprehension] - - -@mlcd.py_class(type_key="mlc.ast.Await", structure="bind") -class Await(expr): - value: expr - - -@mlcd.py_class(type_key="mlc.ast.Yield", structure="bind") -class Yield(expr): - value: Optional[expr] - - -@mlcd.py_class(type_key="mlc.ast.YieldFrom", structure="bind") -class YieldFrom(expr): - value: expr - - -@mlcd.py_class(type_key="mlc.ast.Compare", structure="bind") -class Compare(expr): - left: expr - ops: list[cmpop] - comparators: list[expr] - - -@mlcd.py_class(type_key="mlc.ast.Call", structure="bind") -class Call(expr): - func: expr - args: list[expr] - keywords: list[keyword] - - -@mlcd.py_class(type_key="mlc.ast.FormattedValue", structure="bind") -class FormattedValue(expr): - value: expr - conversion: int - format_spec: Optional[expr] - - -@mlcd.py_class(type_key="mlc.ast.JoinedStr", structure="bind") -class JoinedStr(expr): - values: list[expr] - - -@mlcd.py_class(type_key="mlc.ast.Ellipsis", structure="bind") -class Ellipsis(mlcd.PyClass): ... - - -@mlcd.py_class(type_key="mlc.ast.Constant", structure="bind") -class Constant(expr): - value: Any # None, str, bytes, bool, int, float, complex, Ellipsis - kind: Optional[str] - - -@mlcd.py_class(type_key="mlc.ast.Starred", structure="bind") -class Starred(expr): - value: expr - ctx: expr_context - - -@mlcd.py_class(type_key="mlc.ast.List", structure="bind") -class List(expr): - elts: list[expr] - ctx: expr_context - - -@mlcd.py_class(type_key="mlc.ast.Tuple", structure="bind") -class Tuple(expr): - elts: list[expr] - ctx: expr_context - dims: list[expr] - - -@mlcd.py_class(type_key="mlc.ast.Slice", structure="bind") -class Slice(expr): - lower: Optional[expr] - upper: Optional[expr] - step: Optional[expr] - - -@mlcd.py_class(type_key="mlc.ast.Load", structure="bind") -class Load(expr_context): ... - - -@mlcd.py_class(type_key="mlc.ast.Store", structure="bind") -class Store(expr_context): ... - - -@mlcd.py_class(type_key="mlc.ast.Del", structure="bind") -class Del(expr_context): ... - - -@mlcd.py_class(type_key="mlc.ast.And", structure="bind") -class And(boolop): ... - - -@mlcd.py_class(type_key="mlc.ast.Or", structure="bind") -class Or(boolop): ... - - -@mlcd.py_class(type_key="mlc.ast.Add", structure="bind") -class Add(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.Sub", structure="bind") -class Sub(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.Mult", structure="bind") -class Mult(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.MatMult", structure="bind") -class MatMult(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.Div", structure="bind") -class Div(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.Mod", structure="bind") -class Mod(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.Pow", structure="bind") -class Pow(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.LShift", structure="bind") -class LShift(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.RShift", structure="bind") -class RShift(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.BitOr", structure="bind") -class BitOr(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.BitXor", structure="bind") -class BitXor(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.BitAnd", structure="bind") -class BitAnd(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.FloorDiv", structure="bind") -class FloorDiv(operator): ... - - -@mlcd.py_class(type_key="mlc.ast.Invert", structure="bind") -class Invert(unaryop): ... - - -@mlcd.py_class(type_key="mlc.ast.Not", structure="bind") -class Not(unaryop): ... - - -@mlcd.py_class(type_key="mlc.ast.UAdd", structure="bind") -class UAdd(unaryop): ... - - -@mlcd.py_class(type_key="mlc.ast.USub", structure="bind") -class USub(unaryop): ... - - -@mlcd.py_class(type_key="mlc.ast.Eq", structure="bind") -class Eq(cmpop): ... - - -@mlcd.py_class(type_key="mlc.ast.NotEq", structure="bind") -class NotEq(cmpop): ... - - -@mlcd.py_class(type_key="mlc.ast.Lt", structure="bind") -class Lt(cmpop): ... - - -@mlcd.py_class(type_key="mlc.ast.LtE", structure="bind") -class LtE(cmpop): ... - - -@mlcd.py_class(type_key="mlc.ast.Gt", structure="bind") -class Gt(cmpop): ... - - -@mlcd.py_class(type_key="mlc.ast.GtE", structure="bind") -class GtE(cmpop): ... - - -@mlcd.py_class(type_key="mlc.ast.Is", structure="bind") -class Is(cmpop): ... - - -@mlcd.py_class(type_key="mlc.ast.IsNot", structure="bind") -class IsNot(cmpop): ... - - -@mlcd.py_class(type_key="mlc.ast.In", structure="bind") -class In(cmpop): ... - - -@mlcd.py_class(type_key="mlc.ast.NotIn", structure="bind") -class NotIn(cmpop): ... - - -@mlcd.py_class(type_key="mlc.ast.MatchValue", structure="bind") -class MatchValue(pattern): - value: expr - - -@mlcd.py_class(type_key="mlc.ast.MatchSingleton", structure="bind") -class MatchSingleton(pattern): - value: int # boolean - - -@mlcd.py_class(type_key="mlc.ast.MatchSequence", structure="bind") -class MatchSequence(pattern): - patterns: list[pattern] - - -@mlcd.py_class(type_key="mlc.ast.MatchMapping", structure="bind") -class MatchMapping(pattern): - keys: list[expr] - patterns: list[pattern] - rest: Optional[_Identifier] - - -@mlcd.py_class(type_key="mlc.ast.MatchClass", structure="bind") -class MatchClass(pattern): - cls: expr - patterns: list[pattern] - kwd_attrs: list[_Identifier] - kwd_patterns: list[pattern] - - -@mlcd.py_class(type_key="mlc.ast.MatchStar", structure="bind") -class MatchStar(pattern): - name: Optional[_Identifier] - - -@mlcd.py_class(type_key="mlc.ast.MatchAs", structure="bind") -class MatchAs(pattern): - pattern: Optional[pattern] - name: Optional[_Identifier] - - -@mlcd.py_class(type_key="mlc.ast.MatchOr", structure="bind") -class MatchOr(pattern): - patterns: list[pattern] - - -@mlcd.py_class(type_key="mlc.ast.TypeVar", structure="bind") -class TypeVar(type_param): - name: _Identifier - bound: Optional[expr] - default_value: Optional[expr] - - -@mlcd.py_class(type_key="mlc.ast.ParamSpec", structure="bind") -class ParamSpec(type_param): - name: _Identifier - default_value: Optional[expr] - - -@mlcd.py_class(type_key="mlc.ast.TypeVarTuple", structure="bind") -class TypeVarTuple(type_param): - name: _Identifier - default_value: Optional[expr] - - -@mlcd.py_class(type_key="mlc.ast.TypeAlias", structure="bind") -class TypeAlias(stmt): - name: Name - type_params: Optional[list[type_param]] - value: expr - - -__all__ = [ - "AST", - "mod", - "expr_context", - "operator", - "cmpop", - "unaryop", - "boolop", - "type_ignore", - "stmt", - "expr", - "type_param", - "pattern", - "arg", - "keyword", - "alias", - "arguments", - "comprehension", - "withitem", - "TypeIgnore", - "Module", - "Interactive", - "Expression", - "FunctionType", - "FunctionDef", - "AsyncFunctionDef", - "ClassDef", - "Return", - "Delete", - "Assign", - "Attribute", - "Subscript", - "AugAssign", - "AnnAssign", - "For", - "AsyncFor", - "While", - "If", - "With", - "AsyncWith", - "match_case", - "Match", - "Raise", - "ExceptHandler", - "Try", - "TryStar", - "Assert", - "Import", - "ImportFrom", - "Global", - "Nonlocal", - "Expr", - "Pass", - "Break", - "Continue", - "BoolOp", - "Name", - "NamedExpr", - "BinOp", - "UnaryOp", - "Lambda", - "IfExp", - "Dict", - "Set", - "ListComp", - "SetComp", - "DictComp", - "GeneratorExp", - "Await", - "Yield", - "YieldFrom", - "Compare", - "Call", - "FormattedValue", - "JoinedStr", - "Constant", - "Starred", - "List", - "Tuple", - "Slice", - "Load", - "Store", - "Del", - "And", - "Or", - "Add", - "Sub", - "Mult", - "MatMult", - "Div", - "Mod", - "Pow", - "LShift", - "RShift", - "BitOr", - "BitXor", - "BitAnd", - "FloorDiv", - "Invert", - "Not", - "UAdd", - "USub", - "Eq", - "NotEq", - "Lt", - "LtE", - "Gt", - "GtE", - "Is", - "IsNot", - "In", - "NotIn", - "MatchValue", - "MatchSingleton", - "MatchSequence", - "MatchMapping", - "MatchClass", - "MatchStar", - "MatchAs", - "MatchOr", - "TypeVar", - "ParamSpec", - "TypeVarTuple", - "TypeAlias", -] diff --git a/python/mlc/ast/translate.py b/python/mlc/ast/translate.py deleted file mode 100644 index cf4010ce..00000000 --- a/python/mlc/ast/translate.py +++ /dev/null @@ -1,109 +0,0 @@ -from __future__ import annotations - -import ast -import typing - -from mlc.core import List - -from . import mlc_ast - -TYPE_PY_TO_MLC_TRANSLATOR = typing.Callable[[ast.AST], mlc_ast.AST] -TYPE_MLC_TO_PY_TRANSLATOR = typing.Callable[[mlc_ast.AST], ast.AST] - - -def translate_py_to_mlc(py_node: ast.AST) -> mlc_ast.AST: - assert isinstance(py_node, ast.AST), f"Expected AST node, got {type(py_node)}" - if (translator := _PY_TO_MLC_VTABLE.get(type(py_node))) is not None: - return translator(py_node) - raise NotImplementedError(f"Translation not implemented: {type(py_node)}") - - -def translate_mlc_to_py(mlc_node: mlc_ast.AST) -> ast.AST: - assert isinstance(mlc_node, mlc_ast.AST), f"Expected AST node, got {type(mlc_node)}" - if (translator := _MLC_TO_PY_VTABLE.get(type(mlc_node))) is not None: - return translator(mlc_node) - raise NotImplementedError(f"Translation not implemented: {type(mlc_node)}") - - -def py_to_mlc_vtable_create() -> dict[type[ast.AST], TYPE_PY_TO_MLC_TRANSLATOR]: - def _translate_field(value: typing.Any) -> typing.Any: - if isinstance(value, ast.AST): - return translate_py_to_mlc(value) - elif value is ...: - return mlc_ast.Ellipsis() - elif isinstance(value, list): - return [_translate_field(item) for item in value] - return value - - def _create_entry( - mlc_cls: type[mlc_ast.AST], py_cls: type[ast.AST] - ) -> TYPE_PY_TO_MLC_TRANSLATOR: - def py_to_mlc_default(py_node: ast.AST) -> mlc_ast.AST: - return mlc_cls( - **{ - field: _translate_field(getattr(py_node, field, None)) - for field in mlc_field_names - } - ) - - mlc_field_names = list(typing.get_type_hints(mlc_cls).keys()) - return py_to_mlc_default - - def _create_entry_constant() -> TYPE_PY_TO_MLC_TRANSLATOR: - # Special handling to convert `bytea` to `str`, because `bytes` is not supported in MLC - entry = _create_entry(mlc_ast.Constant, ast.Constant) - - def py_to_mlc_constant(py_node: ast.Constant) -> mlc_ast.Constant: - if isinstance(py_node.value, bytes): - py_node.value = py_node.value.decode("utf-8") - return entry(py_node) - - return typing.cast(TYPE_PY_TO_MLC_TRANSLATOR, py_to_mlc_constant) - - vtable: dict[type[ast.AST], TYPE_PY_TO_MLC_TRANSLATOR] = { - type(...): lambda _: mlc_ast.Ellipsis(), # type: ignore - ast.Constant: _create_entry_constant(), - } - for cls_name in mlc_ast.__all__: - mlc_cls = getattr(mlc_ast, cls_name) - if py_cls := getattr(ast, cls_name, None): - if py_cls not in vtable: - vtable[py_cls] = _create_entry(mlc_cls, py_cls) - return vtable - - -def mlc_to_py_vtable_create() -> dict[type[mlc_ast.AST], TYPE_MLC_TO_PY_TRANSLATOR]: - def _translate_field(value: typing.Any) -> typing.Any: - if isinstance(value, mlc_ast.AST): - return translate_mlc_to_py(value) - elif isinstance(value, (list, List)): - return [_translate_field(item) for item in value] - return value - - def _create_entry( - py_cls: type[ast.AST], mlc_cls: type[mlc_ast.AST] - ) -> TYPE_MLC_TO_PY_TRANSLATOR: - def mlc_to_py_default(mlc_node: mlc_ast.AST) -> ast.AST: - return py_cls( - **{ - field: _translate_field(getattr(mlc_node, field, None)) - for field in py_field_names - } - ) - - py_field_names = list(typing.get_type_hints(mlc_cls).keys()) - return mlc_to_py_default - - vtable: dict[type[mlc_ast.AST], TYPE_MLC_TO_PY_TRANSLATOR] = { - mlc_ast.Ellipsis: lambda _: ..., # type: ignore - } - for cls_name in mlc_ast.__all__: - mlc_cls = getattr(mlc_ast, cls_name) - if py_cls := getattr(ast, cls_name, None): - if mlc_cls not in vtable: - vtable[mlc_cls] = _create_entry(py_cls, mlc_cls) - return vtable - - -_PY_TO_MLC_VTABLE = py_to_mlc_vtable_create() -_MLC_TO_PY_VTABLE = mlc_to_py_vtable_create() diff --git a/python/mlc/core/typing.py b/python/mlc/core/typing.py index 7b38ade8..a3564563 100644 --- a/python/mlc/core/typing.py +++ b/python/mlc/core/typing.py @@ -63,9 +63,10 @@ def _ctype(self) -> typing.Any: @c_class_core("mlc.core.typing.PtrType") class PtrType(Type): - @property - def ty(self) -> Type: - return self._C(b"_ty", self) + ty: Type + + def __init__(self, ty: Type) -> None: + self._mlc_init(ty) def args(self) -> tuple: return (self.ty,) @@ -76,13 +77,11 @@ def _ctype(self) -> typing.Any: @c_class_core("mlc.core.typing.Optional") class Optional(Type): + ty: Type + def __init__(self, ty: Type) -> None: self._mlc_init(ty) - @property - def ty(self) -> Type: - return self._C(b"_ty", self) - def args(self) -> tuple: return (self.ty,) @@ -92,13 +91,11 @@ def _ctype(self) -> typing.Any: @c_class_core("mlc.core.typing.List") class List(Type): + ty: Type + def __init__(self, ty: Type) -> None: self._mlc_init(ty) - @property - def ty(self) -> Type: - return self._C(b"_ty", self) - def args(self) -> tuple: return (self.ty,) @@ -108,19 +105,14 @@ def _ctype(self) -> typing.Any: @c_class_core("mlc.core.typing.Dict") class Dict(Type): - def __init__(self, key_ty: Type, value_ty: Type) -> None: - self._mlc_init(key_ty, value_ty) - - @property - def key(self) -> Type: - return self._C(b"_key", self) + ty_k: Type + ty_v: Type - @property - def value(self) -> Type: - return self._C(b"_value", self) + def __init__(self, ty_k: Type, ty_v: Type) -> None: + self._mlc_init(ty_k, ty_v) def args(self) -> tuple: - return (self.key, self.value) + return (self.ty_k, self.ty_v) def _ctype(self) -> typing.Any: return MLCObjPtr diff --git a/python/mlc/dataclasses/utils.py b/python/mlc/dataclasses/utils.py index 91b94311..1fa3a37d 100644 --- a/python/mlc/dataclasses/utils.py +++ b/python/mlc/dataclasses/utils.py @@ -78,13 +78,16 @@ def bind_args(*args: typing.Any, **kwargs: typing.Any) -> inspect.BoundArguments return bound def method(self: type, *args: typing.Any, **kwargs: typing.Any) -> None: + e = None try: args = bind_args(*args, **kwargs).args args = tuple(arg.fn() if isinstance(arg, DefaultFactory) else arg for arg in args) args = tuple(args[order] for order in ordering) self._mlc_init(*args) # type: ignore[attr-defined] - except Exception as e: - raise TypeError(f"Error in `{signature_str}`: {e}") # type: ignore[attr-defined] + except Exception as _e: + e = TypeError(f"Error in `{signature_str}`: {_e}").with_traceback(_e.__traceback__) + if e is not None: + raise e try: post_init = self.__post_init__ # type: ignore[attr-defined] except AttributeError: diff --git a/python/mlc/parser/__init__.py b/python/mlc/parser/__init__.py new file mode 100644 index 00000000..2ecaa846 --- /dev/null +++ b/python/mlc/parser/__init__.py @@ -0,0 +1,3 @@ +from .diagnostic import DiagnosticError +from .env import Env, Span, check_decorator +from .parser import Frame, Parser diff --git a/python/mlc/parser/diagnostic.py b/python/mlc/parser/diagnostic.py new file mode 100644 index 00000000..42742d02 --- /dev/null +++ b/python/mlc/parser/diagnostic.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import os +import sys +from types import TracebackType +from typing import Any, Literal + +import colorama # type: ignore[import-untyped] + +from .env import Span + +colorama.init(autoreset=True) + + +class DiagnosticError(Exception): + pass + + +def raise_diagnostic_error( + source: str, + source_name: str, + span: Span, + err: Exception | str, +) -> None: + diag_err = DiagnosticError("Diagnostics were emitted, please check rendered error message.") + if isinstance(err, Exception): + msg = type(err).__name__ + ": " + [i for i in str(err).split("\n") if i][-1] + diag_err.with_traceback(err.__traceback__) + else: + msg = str(err) + print( + _render_at( + source=source, + source_name=source_name, + span=span, + message=msg, + level="error", + ), + file=sys.stderr, + ) + if isinstance(err, Exception): + raise diag_err from err + else: + raise diag_err + + +def _render_at( + source: str, + source_name: str, + span: Span, + message: str, + level: Literal["warning", "error", "bug", "note"] = "error", +) -> str: + lines = source.splitlines() + row_st = max(1, span.row_st) + row_ed = min(len(lines), span.row_ed) + # If no valid rows, just return the bare message. + if row_st > row_ed: + return message + # Map the "level" to a color and label (similar to rang::fg usage in C++). + color, diag_type = { + "warning": (colorama.Fore.YELLOW, "warning"), + "error": (colorama.Fore.RED, "error"), + "bug": (colorama.Fore.BLUE, "bug"), + "note": (colorama.Fore.RESET, "note"), + "help": (colorama.Fore.RESET, "help"), + }.get(level, (colorama.Fore.RED, "error")) + + # Prepare lines of output + out_lines = [ + f"{colorama.Style.BRIGHT}{color}{diag_type}{colorama.Style.RESET_ALL}: {message}", + f"{colorama.Fore.BLUE} --> {colorama.Style.RESET_ALL}{source_name}:{row_st}:{span.col_st}", + ] + left_margin_width = len(str(row_ed)) + for row_idx in range(row_st, row_ed + 1): + line_text = lines[row_idx - 1] # zero-based + line_label = str(row_idx).rjust(left_margin_width) + # Step 1. the actual source line + out_lines.append(f"{line_label} | {line_text}") + # Step 2. the marker line: + marker = [" "] * len(line_text) + # For the first line... + if row_idx == row_st and row_idx == row_ed: + # Case 1. Single-line: highlight col_st..col_ed + c_start = max(1, span.col_st) + c_end = min(len(line_text), span.col_ed) + marker[c_start:c_end] = "^" * (c_end - c_start) + elif row_idx == row_st: + # Case 2. The first line in a multi-line highlight + c_start = max(1, span.col_st) + marker[c_start:] = "^" * (len(line_text) - c_start) + elif row_idx == row_ed: + # Case 3. The last line in a multi-line highlight + c_end = min(len(line_text), span.col_ed) + marker[:c_end] = "^" * c_end + else: + # Case 4. A line in the middle of row_st..row_ed => highlight entire line + marker = ["^"] * len(line_text) + out_lines.append(f"{' ' * (left_margin_width)} | {''.join(marker)}") + return "\n".join(out_lines) + + +def excepthook( + exctype: type[BaseException], + value: BaseException, + traceback: TracebackType | None, +) -> Any: + should_hide_backtrace = os.environ.get("MLC_BACKTRACE", None) is None + if exctype is DiagnosticError and should_hide_backtrace: + print("note: run with `MLC_BACKTRACE=1` environment variable to display a backtrace.") + return + sys_excepthook(exctype, value, traceback) + + +sys_excepthook = sys.excepthook +sys.excepthook = excepthook diff --git a/python/mlc/ast/inspection.py b/python/mlc/parser/env.py similarity index 60% rename from python/mlc/ast/inspection.py rename to python/mlc/parser/env.py index cb1ddb39..f273f182 100644 --- a/python/mlc/ast/inspection.py +++ b/python/mlc/parser/env.py @@ -4,14 +4,17 @@ import dataclasses import inspect from collections.abc import Callable, Generator -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + import ast PY_GETFILE = inspect.getfile PY_FINDSOURCE = inspect.findsource -@dataclasses.dataclass(init=False) -class InspectResult: +@dataclasses.dataclass +class Env: source_name: str source_start_line: int source_start_column: int @@ -20,65 +23,89 @@ class InspectResult: captured: dict[str, Any] annotations: dict[str, dict[str, Any]] - def is_defined_in_class( - self, - frames: list[inspect.FrameInfo], # obtained via `inspect.stack()` - *, - frame_offset: int = 2, - is_decorator: Callable[[str], bool] = lambda line: line.startswith("@"), - ) -> bool: - # Step 1. Inspect `frames[frame_offset]` - try: - lineno = frames[frame_offset].lineno - line = frames[frame_offset].code_context[0].strip() # type: ignore - except: - return False - # Step 2. Determine by the line itself - if is_decorator(line): - return True - if not line.startswith("class"): - return False - # Step 3. Determine by its decorators - source_lines = self.source_full.splitlines(keepends=True) - lineno_offset = 2 - try: - source_line = source_lines[lineno - lineno_offset] - except IndexError: - return False - return is_decorator(source_line.strip()) - - -def inspect_program(program: Callable | type) -> InspectResult: - ret = InspectResult() - source = inspect_source(program) - ret.source_name, ret.source_start_line, ret.source_start_column, ret.source, ret.source_full = ( - source.source_name, - source.source_start_line, - source.source_start_column, - source.source, - source.source_full, - ) - if inspect.isfunction(program): - ret.captured = inspect_capture_function(program) - ret.annotations = inspect_annotations_function(program) - elif inspect.isclass(program): - ret.captured = inspect_capture_class(program) - ret.annotations = inspect_annotations_class(program) - else: - raise TypeError(f"{program!r} is not a function or class") - return ret + @staticmethod + def from_class(program: type) -> Env: + return Env( + **_inspect_source(program), # type: ignore[arg-type] + captured=_inspect_capture_class(program), + annotations=_inspect_annotations_class(program), + ) + + @staticmethod + def from_function(program: Callable) -> Env: + return Env( + **_inspect_source(program), # type: ignore[arg-type] + captured=_inspect_capture_function(program), + annotations=_inspect_annotations_function(program), + ) + + +@dataclasses.dataclass +class Span: + row_st: int + row_ed: int + col_st: int + col_ed: int + + @staticmethod + def from_ast(node: ast.AST, env: Env | None = None) -> Span: + row_st: int = getattr(node, "lineno", None) or 1 + row_ed: int = getattr(node, "end_lineno", None) or row_st + col_st: int = getattr(node, "col_offset", None) or 1 + col_ed: int = getattr(node, "end_col_offset", None) or col_st + if env is not None: + row_st += env.source_start_line - 1 + row_ed += env.source_start_line - 1 + col_st += env.source_start_column + col_ed += env.source_start_column + return Span( + row_st=row_st, + row_ed=row_ed, + col_st=col_st, + col_ed=col_ed, + ) + + +def check_decorator( + frames: list[inspect.FrameInfo], # obtained via `inspect.stack()` + *, + source_full: str | None = None, + frame_offset: int = 2, + checker: Callable[[str], bool] = lambda line: line.startswith("@"), +) -> bool: + # Step 1. Inspect `frames[frame_offset]` + try: + lineno = frames[frame_offset].lineno + line = frames[frame_offset].code_context[0] # type: ignore + except: + return False + # Step 2. Determine by the line itself + if checker(line): + return True + if not line.startswith("class"): + return False + # Step 3. Determine by its decorators + if source_full is None: + return False + source_lines = source_full.splitlines(keepends=True) + lineno_offset = 2 + try: + source_line = source_lines[lineno - lineno_offset] + except IndexError: + return False + return checker(source_line) @contextlib.contextmanager -def override_getfile() -> Generator[None, Any, None]: +def _override_getfile() -> Generator[None, Any, None]: try: - inspect.getfile = getfile # type: ignore[assignment] + inspect.getfile = _getfile # type: ignore[assignment] yield finally: inspect.getfile = PY_GETFILE # type: ignore[assignment] -def getfile(obj: Any) -> str: +def _getfile(obj: Any) -> str: if not inspect.isclass(obj): return PY_GETFILE(obj) mod = getattr(obj, "__module__", None) @@ -92,16 +119,16 @@ def getfile(obj: Any) -> str: if inspect.isfunction(member): if obj.__qualname__ + "." + member.__name__ == member.__qualname__: return inspect.getfile(member) - raise TypeError(f"Source for {obj:!r} not found") + raise TypeError(f"Source for {obj} not found") -def getsourcelines(obj: Any) -> tuple[list[str], int]: +def _getsourcelines(obj: Any) -> tuple[list[str], int]: obj = inspect.unwrap(obj) - lines, l_num = findsource(obj) + lines, l_num = _findsource(obj) return inspect.getblock(lines[l_num:]), l_num + 1 -def findsource(obj: Any) -> tuple[list[str], int]: # noqa: PLR0912 +def _findsource(obj: Any) -> tuple[list[str], int]: # noqa: PLR0912 if not inspect.isclass(obj): return PY_FINDSOURCE(obj) @@ -154,19 +181,10 @@ def findsource(obj: Any) -> tuple[list[str], int]: # noqa: PLR0912 raise OSError("could not find class definition") -@dataclasses.dataclass -class Source: - source_name: str - source_start_line: int - source_start_column: int - source: str - source_full: str - - -def inspect_source(program: Callable | type) -> Source: - with override_getfile(): +def _inspect_source(program: Callable | type) -> dict[str, int | str]: + with _override_getfile(): source_name: str = inspect.getsourcefile(program) # type: ignore - lines, source_start_line = getsourcelines(program) # type: ignore + lines, source_start_line = _getsourcelines(program) # type: ignore if lines: source_start_column = len(lines[0]) - len(lines[0].lstrip()) else: @@ -190,20 +208,20 @@ def inspect_source(program: Callable | type) -> Source: # as a fallback method. src, _ = inspect.findsource(program) # type: ignore source_full = "".join(src) - return Source( - source_name=source_name, - source_start_line=source_start_line, - source_start_column=source_start_column, - source=source, - source_full=source_full, - ) + return dict( + source_name=source_name, + source_start_line=source_start_line, + source_start_column=source_start_column, + source=source, + source_full=source_full, + ) -def inspect_annotations_function(program: Callable | type) -> dict[str, dict[str, Any]]: +def _inspect_annotations_function(program: Callable | type) -> dict[str, dict[str, Any]]: return {program.__name__: program.__annotations__} -def inspect_annotations_class(program: Callable | type) -> dict[str, dict[str, Any]]: +def _inspect_annotations_class(program: Callable | type) -> dict[str, dict[str, Any]]: annotations = {} for name, func in program.__dict__.items(): if inspect.isfunction(func): @@ -211,13 +229,13 @@ def inspect_annotations_class(program: Callable | type) -> dict[str, dict[str, A return annotations -def inspect_capture_function(func: Callable) -> dict[str, Any]: +def _inspect_capture_function(func: Callable) -> dict[str, Any]: def _getclosurevars(func: Callable) -> dict[str, Any]: # Mofiied from `inspect.getclosurevars` if inspect.ismethod(func): func = func.__func__ if not inspect.isfunction(func): - raise TypeError(f"{func!r} is not a Python function") + raise TypeError(f"Not a Python function: {func}") code = func.__code__ # Nonlocal references are named in co_freevars and resolved # by looking them up in __closure__ by positional index @@ -238,10 +256,10 @@ def _getclosurevars(func: Callable) -> dict[str, Any]: } -def inspect_capture_class(cls: type) -> dict[str, Any]: +def _inspect_capture_class(cls: type) -> dict[str, Any]: result: dict[str, Any] = {} for _, v in cls.__dict__.items(): if inspect.isfunction(v): - func_vars = inspect_capture_function(v) + func_vars = _inspect_capture_function(v) result.update(**func_vars) return result diff --git a/python/mlc/parser/parser.py b/python/mlc/parser/parser.py new file mode 100644 index 00000000..944b716d --- /dev/null +++ b/python/mlc/parser/parser.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import ast +import builtins +import copy +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any + +from .diagnostic import DiagnosticError, raise_diagnostic_error +from .env import Env, Span + +Value = Any + + +class Frame: + name2value: dict[str, Value] + + def __init__(self) -> None: + self.name2value = {} + + def add(self, name: str, value: Value, override: bool = False) -> None: + if name in self.name2value and not override: + raise ValueError(f"Variable already defined in current scope: {name}") + self.name2value[name] = value + + @contextmanager + def scope(self, parser: Parser) -> Generator[Frame, None, None]: + parser.frames.append(self) + try: + yield self + finally: + parser.frames.pop() + + +class Parser: + env: Env + frames: list[Frame] + + def __init__( + self, + env: Env, + extra_vars: dict[str, Value] | None = None, + include_builtins: bool = True, + ) -> None: + self.env = env + self.frames = [Frame()] + if include_builtins: + for name, value in builtins.__dict__.items(): + self.var_def(name, value) + if extra_vars is not None: + for name, value in extra_vars.items(): + self.var_def(name, value) + + def var_def( + self, + name: str, + value: Value, + frame: Frame | None = None, + override: bool = False, + ) -> None: + if frame is None: + frame = self.frames[-1] + frame.add(name, value, override=override) + + def report_error(self, node: ast.AST | Span, err: Exception | str) -> None: + if isinstance(err, DiagnosticError): + raise err + span = Span.from_ast(node, self.env) if isinstance(node, ast.AST) else node + raise_diagnostic_error( + source=self.env.source_full, + source_name=self.env.source_name, + span=span, + err=err, + ) + + def eval_expr(self, node: ast.expr) -> Value: + var_tab: dict[str, Value] = {} + for frame in self.frames: + var_tab.update({name: value for name, value in frame.name2value.items()}) + return ExprEvaluator(parser=self, var_tab=var_tab).run(node) + + def eval_assign(self, target: ast.expr, source: Value) -> dict[str, Value]: + var_name = "__mlc_rhs_var__" + dict_locals: dict[str, Value] = {var_name: source} + mod = ast.fix_missing_locations( + ast.Module( + body=[ + ast.Assign( + targets=[target], + value=ast.Name(id=var_name, ctx=ast.Load()), + ) + ], + type_ignores=[], + ) + ) + exe = compile(mod, filename="", mode="exec") + exec(exe, {}, dict_locals) # pylint: disable=exec-used + del dict_locals[var_name] + return dict_locals + + +class ExprEvaluator(ast.NodeTransformer): + parser: Parser + var_tab: dict[str, Value] + num_tmp_vars: int + + def __init__(self, parser: Parser, var_tab: dict[str, Value]) -> None: + self.parser = parser + self.var_tab = var_tab + self.num_tmp_vars = 0 + + def run(self, node: ast.AST) -> Value: + var_name = "_mlc_evaluator_tmp_var_result" + self._eval_subexpr( + expr=self.visit(node), + target=ast.Name(id=var_name, ctx=ast.Store()), + ) + return self.var_tab[var_name] + + def _make_intermediate_var(self) -> ast.Name: + var_name = f"_mlc_evaluator_tmp_var_{self.num_tmp_vars}" + self.num_tmp_vars += 1 + return ast.Name(id=var_name, ctx=ast.Store()) + + def visit_Name(self, node: ast.Name) -> ast.Name: + if node.id not in self.var_tab: + self.parser.report_error(node, f"Undefined variable: {node.id}") + return node + + def visit_Call(self, node: ast.Call) -> ast.Name: + node.func = self.visit(node.func) + node.args = [self.visit(arg) for arg in node.args] + node.keywords = [self.visit(kw) for kw in node.keywords] + return self._eval_subexpr( + expr=node, + target=self._make_intermediate_var(), + ) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.Name: + node.operand = self.visit(node.operand) + return self._eval_subexpr( + expr=node, + target=self._make_intermediate_var(), + ) + + def visit_BinOp(self, node: ast.BinOp) -> ast.Name: + node.left = self.visit(node.left) + node.right = self.visit(node.right) + return self._eval_subexpr( + expr=node, + target=self._make_intermediate_var(), + ) + + def visit_Compare(self, node: ast.Compare) -> ast.Name: + node.left = self.visit(node.left) + node.comparators = [self.visit(comp) for comp in node.comparators] + return self._eval_subexpr( + expr=node, + target=self._make_intermediate_var(), + ) + + def visit_BoolOp(self, node: ast.BoolOp) -> ast.Name: + node.values = [self.visit(value) for value in node.values] + return self._eval_subexpr( + expr=node, + target=self._make_intermediate_var(), + ) + + def visit_IfExp(self, node: ast.IfExp) -> ast.Name: + node.test = self.visit(node.test) + node.body = self.visit(node.body) + node.orelse = self.visit(node.orelse) + return self._eval_subexpr( + expr=node, + target=self._make_intermediate_var(), + ) + + def visit_Attribute(self, node: ast.Attribute) -> ast.Name: + node.value = self.visit(node.value) + return self._eval_subexpr( + expr=node, + target=self._make_intermediate_var(), + ) + + def visit_Subscript(self, node: ast.Subscript) -> ast.Name: + node.value = self.visit(node.value) + node.slice = self.visit(node.slice) + return self._eval_subexpr( + expr=node, + target=self._make_intermediate_var(), + ) + + def visit_Slice(self, node: ast.Slice) -> ast.Name: + if node.lower is not None: + node.lower = self.visit(node.lower) + if node.upper is not None: + node.upper = self.visit(node.upper) + if node.step is not None: + node.step = self.visit(node.step) + return self._eval_subexpr( + expr=node, + target=self._make_intermediate_var(), + ) + + def _eval_subexpr(self, expr: ast.expr, target: ast.Name) -> ast.Name: + target.ctx = ast.Store() + mod = ast.fix_missing_locations( + ast.Module( + body=[ast.Assign(targets=[target], value=copy.copy(expr))], # target = node + type_ignores=[], + ) + ) + exe = compile(mod, filename="", mode="exec") + try: + exec(exe, {}, self.var_tab) + except Exception as err: + self.parser.report_error(node=expr, err=err) + target.ctx = ast.Load() + return target diff --git a/python/mlc/printer/ir_printer.py b/python/mlc/printer/ir_printer.py index da9603ab..37cbe6db 100644 --- a/python/mlc/printer/ir_printer.py +++ b/python/mlc/printer/ir_printer.py @@ -5,7 +5,7 @@ import mlc.dataclasses as mlcd from mlc.core import Func, Object, ObjectPath -from .ast import Id, Node, PrinterConfig, Stmt +from .ast import Expr, Node, PrinterConfig, Stmt from .cprint import cprint @@ -44,16 +44,26 @@ def __init__(self, cfg: Optional[PrinterConfig] = None) -> None: def var_is_defined(self, obj: Any) -> bool: return bool(IRPrinter._C(b"var_is_defined", self, obj)) - def var_def(self, obj: Any, frame: Any, name: str) -> Id: - return IRPrinter._C(b"var_def", self, obj, frame, name) - - def var_def_no_name(self, obj: Any, creator: Func) -> None: - IRPrinter._C(b"var_def_no_name", self, obj, creator) + def var_def( + self, + name: str, + obj: Any, + frame: Optional[Any] = None, + ) -> None: + return IRPrinter._C(b"var_def", self, name, obj, frame) + + def var_def_no_name( + self, + creator: Func, + obj: Any, + frame: Optional[Any] = None, + ) -> None: + IRPrinter._C(b"var_def_no_name", self, creator, obj, frame) def var_remove(self, obj: Any) -> None: IRPrinter._C(b"var_remove", self, obj) - def var_get(self, obj: Any) -> Optional[Id]: + def var_get(self, obj: Any) -> Optional[Expr]: return IRPrinter._C(b"var_get", self, obj) def frame_push(self, frame: Any) -> None: diff --git a/python/mlc/testing/__init__.py b/python/mlc/testing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/mlc/testing/toy_ir/__init__.py b/python/mlc/testing/toy_ir/__init__.py new file mode 100644 index 00000000..61206a55 --- /dev/null +++ b/python/mlc/testing/toy_ir/__init__.py @@ -0,0 +1,3 @@ +from .ir import Add, Assign, Expr, Func, Node, Stmt, Var +from .ir_builder import FunctionFrame, IRBuilder +from .parser import Parser, parse_func diff --git a/python/mlc/testing/toy_ir/ir.py b/python/mlc/testing/toy_ir/ir.py new file mode 100644 index 00000000..5f2e7ad1 --- /dev/null +++ b/python/mlc/testing/toy_ir/ir.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import mlc.dataclasses as mlcd +import mlc.printer as mlcp +import mlc.printer.ast as mlt + + +@mlcd.py_class +class Node(mlcd.PyClass): ... + + +@mlcd.py_class +class Expr(Node): ... + + +@mlcd.py_class +class Stmt(Node): ... + + +@mlcd.py_class(structure="var") +class Var(Expr): + name: str = mlcd.field(structure=None) + + def __add__(self, other: Var) -> Add: + return Add(lhs=self, rhs=other) + + def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: + if not printer.var_is_defined(obj=self): + printer.var_def(self.name, obj=self) + ret = printer.var_get(obj=self) + assert ret is not None + return ret + + +@mlcd.py_class(structure="nobind") +class Add(Expr): + lhs: Expr + rhs: Expr + + def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: + lhs: mlt.Expr = printer(self.lhs, path=path["a"]) + rhs: mlt.Expr = printer(self.rhs, path=path["b"]) + return lhs + rhs + + +@mlcd.py_class(structure="bind") +class Assign(Stmt): + rhs: Expr + lhs: Var = mlcd.field(structure="bind") + + def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: + rhs: mlt.Expr = printer(self.rhs, path=path["b"]) + printer.var_def(self.lhs.name, obj=self.lhs) + lhs: mlt.Expr = printer(self.lhs, path=path["a"]) + return mlt.Assign(lhs=lhs, rhs=rhs) + + +@mlcd.py_class(structure="bind") +class Func(Node): + name: str = mlcd.field(structure=None) + args: list[Var] = mlcd.field(structure="bind") + stmts: list[Stmt] + ret: Var + + def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: + with printer.with_frame(mlcp.DefaultFrame()): + for arg in self.args: + printer.var_def(arg.name, obj=arg) + args: list[mlt.Expr] = [ + printer(arg, path=path["args"][i]) for i, arg in enumerate(self.args) + ] + stmts: list[mlt.Expr] = [ + printer(stmt, path=path["stmts"][i]) for i, stmt in enumerate(self.stmts) + ] + ret_stmt = mlt.Return(printer(self.ret, path=path["ret"])) + return mlt.Function( + name=mlt.Id(self.name), + args=[mlt.Assign(lhs=arg, rhs=None) for arg in args], + decorators=[], + return_type=None, + body=[*stmts, ret_stmt], + ) diff --git a/python/mlc/testing/toy_ir/ir_builder.py b/python/mlc/testing/toy_ir/ir_builder.py new file mode 100644 index 00000000..797f2a76 --- /dev/null +++ b/python/mlc/testing/toy_ir/ir_builder.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from types import TracebackType +from typing import Any, ClassVar + +from .ir import Func, Stmt, Var + + +class IRBuilder: + _ctx: ClassVar[list[IRBuilder]] = [] + frames: list[Any] + result: Any + + def __init__(self) -> None: + self.frames = [] + self.result = None + + def __enter__(self) -> IRBuilder: + IRBuilder._ctx.append(self) + return self + + def __exit__( + self, + exc_type: type[BaseException], + exc_value: BaseException, + traceback: TracebackType, + ) -> None: + IRBuilder._ctx.pop() + + @staticmethod + def get() -> IRBuilder: + return IRBuilder._ctx[-1] + + +class FunctionFrame: + name: str + args: list[Var] + stmts: list[Stmt] + ret: Var | None + + def __init__(self, name: str) -> None: + self.name = name + self.args = [] + self.stmts = [] + self.ret = None + + def add_arg(self, arg: Var) -> Var: + self.args.append(arg) + return arg + + def __enter__(self) -> FunctionFrame: + IRBuilder.get().frames.append(self) + return self + + def __exit__( + self, + exc_type: type[BaseException], + exc_value: BaseException, + traceback: TracebackType, + ) -> None: + frame = IRBuilder.get().frames.pop() + assert frame is self + if exc_type is None: + IRBuilder.get().result = Func( + name=self.name, + args=frame.args, + stmts=frame.stmts, + ret=frame.ret, + ) diff --git a/python/mlc/testing/toy_ir/parser.py b/python/mlc/testing/toy_ir/parser.py new file mode 100644 index 00000000..5e9012eb --- /dev/null +++ b/python/mlc/testing/toy_ir/parser.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import ast +from typing import Callable + +import mlc.parser as mlcs + +from .ir import Assign, Node, Var +from .ir_builder import FunctionFrame, IRBuilder + + +class Parser(ast.NodeVisitor): + base: mlcs.Parser + + def __init__(self, env: mlcs.Env) -> None: + self.base = mlcs.Parser(env, include_builtins=True, extra_vars=None) + + def visit_Assign(self, node: ast.Assign) -> None: + if len(node.targets) != 1: + self.base.report_error(node, "Continuous assignment is not supported") + (target,) = node.targets + if not isinstance(target, ast.Name): + self.base.report_error(target, "Invalid assignment target") + assert isinstance(target, ast.Name) + value = self.base.eval_assign( + target=target, + source=self.base.eval_expr(node.value), + )[target.id] + var = Var(name=target.id) + self.base.var_def(name=target.id, value=var) + IRBuilder.get().frames[-1].stmts.append( + Assign( + lhs=var, + rhs=value, + ) + ) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + with mlcs.Frame().scope(self.base): + with FunctionFrame(node.name) as frame: + for node_arg in node.args.args: + self.base.var_def( + name=node_arg.arg, + value=frame.add_arg(Var(name=node_arg.arg)), + ) + for node_stmt in node.body: + self.visit(node_stmt) + + def visit_Return(self, node: ast.Return) -> None: + if not isinstance(node.value, ast.Name): + self.base.report_error(node, "Return statement must return a single variable") + assert isinstance(node.value, ast.Name) + frame: FunctionFrame = IRBuilder.get().frames[-1] + frame.ret = self.base.eval_expr(node.value) + + +def parse_func(source: Callable) -> Node: + env = mlcs.Env.from_function(source) + parser = Parser(env) + node = ast.parse( + env.source, + filename=env.source_name, + ) + with IRBuilder() as ib: + parser.visit(node) + return ib.result diff --git a/tests/cpp/test_base_optional.cc b/tests/cpp/test_base_optional.cc index df855185..c166b762 100644 --- a/tests/cpp/test_base_optional.cc +++ b/tests/cpp/test_base_optional.cc @@ -274,7 +274,7 @@ TEST(OptionalAnyConversion, ObjectRefType) { // Construct from AnyView Tests TEST(OptionalConstructFromAnyView, IntType) { AnyView view(42); - Optional opt_int(view); + Optional opt_int = view.operator Optional(); // TODO: recover implicit casting EXPECT_TRUE(opt_int.defined()); EXPECT_EQ(*opt_int, 42); } @@ -282,7 +282,7 @@ TEST(OptionalConstructFromAnyView, IntType) { TEST(OptionalConstructFromAnyView, ObjectRefType) { TestObjRef obj(Ref::New(10)); AnyView obj_view(obj); - Optional opt_obj(obj_view); + Optional opt_obj = obj_view.operator Optional(); // TODO: recover implicit casting EXPECT_TRUE(opt_obj.defined()); EXPECT_EQ(opt_obj->value, 10); } @@ -290,7 +290,7 @@ TEST(OptionalConstructFromAnyView, ObjectRefType) { // Construct from Any Tests TEST(OptionalConstructFromAny, IntType) { Any any(42); - Optional opt_int(any); + Optional opt_int = any.operator Optional(); // TODO: recover implicit casting EXPECT_TRUE(opt_int.defined()); EXPECT_EQ(*opt_int, 42); } @@ -298,7 +298,7 @@ TEST(OptionalConstructFromAny, IntType) { TEST(OptionalConstructFromAny, ObjectRefType) { TestObjRef obj(Ref::New(10)); Any obj_any(obj); - Optional opt_obj(obj_any); + Optional opt_obj = obj_any.operator Optional(); // TODO: recover implicit casting EXPECT_TRUE(opt_obj.defined()); EXPECT_EQ(opt_obj->value, 10); } diff --git a/tests/cpp/test_base_ref.cc b/tests/cpp/test_base_ref.cc index 36e425ee..4205c086 100644 --- a/tests/cpp/test_base_ref.cc +++ b/tests/cpp/test_base_ref.cc @@ -303,7 +303,7 @@ TEST(RefPOD, ConversionToAny) { Any any = ref; ref.Reset(); EXPECT_EQ(any.operator int64_t(), 42); - ref = Ref(any); + ref = any.operator Ref(); EXPECT_EQ(*ref, 42); EXPECT_EQ(GetRefCount(ref), 1); } @@ -319,7 +319,7 @@ TEST(RefPOD, ConversionToAnyView) { AnyView any_view = ref; ref.Reset(); EXPECT_EQ(any_view.operator int64_t(), 42); - ref = Ref(any_view); + ref = any_view.operator Ref(); EXPECT_EQ(*ref, 42); EXPECT_EQ(GetRefCount(ref), 1); } diff --git a/tests/python/test_parser_toy_ir_parser.py b/tests/python/test_parser_toy_ir_parser.py new file mode 100644 index 00000000..3c502a15 --- /dev/null +++ b/tests/python/test_parser_toy_ir_parser.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from mlc.testing import toy_ir +from mlc.testing.toy_ir import Add, Assign, Func, Var + + +def test_parse_func() -> None: + def source_code(a, b, c): # noqa: ANN001, ANN202 + d = a + b + e = d + c + return e + + def _expected() -> Func: + a = Var(name="_a") + b = Var(name="_b") + c = Var(name="_c") + d = Var(name="_d") + e = Var(name="_e") + stmts = [ + Assign(lhs=d, rhs=Add(a, b)), + Assign(lhs=e, rhs=Add(d, c)), + ] + f = Func(name="_f", args=[a, b, c], stmts=stmts, ret=e) + return f + + result = toy_ir.parse_func(source_code) + expected = _expected() + result.eq_s(expected, assert_mode=True) diff --git a/tests/python/test_printer_ir_printer.py b/tests/python/test_printer_ir_printer.py index 61f04bb8..d7f8057b 100644 --- a/tests/python/test_printer_ir_printer.py +++ b/tests/python/test_printer_ir_printer.py @@ -1,76 +1,5 @@ -import mlc.dataclasses as mlcd import mlc.printer as mlcp -from mlc.printer import ast as mlt - - -@mlcd.py_class -class Expr(mlcd.PyClass): ... - - -@mlcd.py_class -class Stmt(mlcd.PyClass): ... - - -@mlcd.py_class -class Var(Expr): - name: str - - def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: - if not printer.var_is_defined(obj=self): - printer.var_def(obj=self, frame=printer.frames[-1], name=self.name) - ret = printer.var_get(obj=self) - assert isinstance(ret, mlt.Id) - return ret - - -@mlcd.py_class -class Add(Expr): - lhs: Expr - rhs: Expr - - def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: - lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"]) - rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"]) - return lhs + rhs - - -@mlcd.py_class -class Assign(Stmt): - lhs: Var - rhs: Expr - - def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: - rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"]) - printer.var_def(obj=self.lhs, frame=printer.frames[-1], name=self.lhs.name) - lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"]) - return mlt.Assign(lhs=lhs, rhs=rhs) - - -@mlcd.py_class -class Func(mlcd.PyClass): - name: str - args: list[Var] - stmts: list[Stmt] - ret: Var - - def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node: - with printer.with_frame(mlcp.DefaultFrame()): - for arg in self.args: - printer.var_def(obj=arg, frame=printer.frames[-1], name=arg.name) - args: list[mlt.Expr] = [ - printer(obj=arg, path=path["args"][i]) for i, arg in enumerate(self.args) - ] - stmts: list[mlt.Expr] = [ - printer(obj=stmt, path=path["stmts"][i]) for i, stmt in enumerate(self.stmts) - ] - ret_stmt = mlt.Return(printer(obj=self.ret, path=path["ret"])) - return mlt.Function( - name=mlt.Id(self.name), - args=[mlt.Assign(lhs=arg, rhs=None) for arg in args], - decorators=[], - return_type=None, - body=[*stmts, ret_stmt], - ) +from mlc.testing.toy_ir import Add, Assign, Func, Var def test_var_print() -> None: