Skip to content

Commit

Permalink
AST Parser
Browse files Browse the repository at this point in the history
  • Loading branch information
potatomashed committed Dec 28, 2024
1 parent 4434b87 commit c1a179c
Show file tree
Hide file tree
Showing 27 changed files with 793 additions and 1,262 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15)

project(
mlc
VERSION 0.0.13
VERSION 0.0.14
DESCRIPTION "MLC-Python"
LANGUAGES C CXX
)
Expand Down
93 changes: 20 additions & 73 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,100 +121,47 @@ ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent bindi

### :snake: Text Formats in Python

**IR Printer.** By defining an `__ir_print__` method, which converts an IR node to MLC's Python-style AST, MLC's `IRPrinter` handles variable scoping, renaming and syntax highlighting automatically for a text format based on Python syntax.
**Printer.** MLC converts an IR node to Python AST by looking up the `__ir_print__` method.

<details><summary>Defining Python-based text format on a toy IR using `__ir_print__`.</summary>
**[[Example](https://github.com/mlc-ai/mlc-python/blob/main/python/mlc/testing/toy_ir/ir.py)]**. Copy the toy IR definition to REPL and then create a `Func` node below:

```python
import mlc.dataclasses as mlcd
import mlc.printer as mlcp
from mlc.printer import ast as mlt

@mlcd.py_class
class Expr(mlcd.PyClass): ...

@mlcd.py_class
class Stmt(mlcd.PyClass): ...

@mlcd.py_class
class Var(Expr):
name: str
def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
if not printer.var_is_defined(obj=self):
printer.var_def(obj=self, frame=printer.frames[-1], name=self.name)
return printer.var_get(obj=self)

@mlcd.py_class
class Add(Expr):
lhs: Expr
rhs: Expr
def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"])
rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"])
return lhs + rhs

@mlcd.py_class
class Assign(Stmt):
lhs: Var
rhs: Expr
def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
rhs: mlt.Expr = printer(obj=self.rhs, path=path["b"])
printer.var_def(obj=self.lhs, frame=printer.frames[-1], name=self.lhs.name)
lhs: mlt.Expr = printer(obj=self.lhs, path=path["a"])
return mlt.Assign(lhs=lhs, rhs=rhs)

@mlcd.py_class
class Func(mlcd.PyClass):
name: str
args: list[Var]
stmts: list[Stmt]
ret: Var
def __ir_print__(self, printer: mlcp.IRPrinter, path: mlcp.ObjectPath) -> mlt.Node:
with printer.with_frame(mlcp.DefaultFrame()):
for arg in self.args:
printer.var_def(obj=arg, frame=printer.frames[-1], name=arg.name)
args: list[mlt.Expr] = [printer(obj=arg, path=path["args"][i]) for i, arg in enumerate(self.args)]
stmts: list[mlt.Expr] = [printer(obj=stmt, path=path["stmts"][i]) for i, stmt in enumerate(self.stmts)]
ret_stmt = mlt.Return(printer(obj=self.ret, path=path["ret"]))
return mlt.Function(
name=mlt.Id(self.name),
args=[mlt.Assign(lhs=arg, rhs=None) for arg in args],
decorators=[],
return_type=None,
body=[*stmts, ret_stmt],
)

# An example IR:
a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e")
f = Func(
name="f",
args=[a, b, c],
>>> a, b, c, d, e = Var("a"), Var("b"), Var("c"), Var("d"), Var("e")
>>> f = Func("f", [a, b, c],
stmts=[
Assign(lhs=d, rhs=Add(a, b)), # d = a + b
Assign(lhs=e, rhs=Add(d, c)), # e = d + c
],
ret=e,
)
ret=e)
```

</details>

Two printer APIs are provided for Python-based text format:
- `mlc.printer.to_python` that converts an IR fragment to Python text, and
- `mlc.printer.print_python` that further renders the text with proper syntax highlighting.
- Method `mlc.printer.to_python` converts an IR node to Python-based text;

```python
>>> print(mlcp.to_python(f)) # Stringify to Python
def f(a, b, c):
d = a + b
e = d + c
return e
```

- Method `mlc.printer.print_python` further renders the text with proper syntax highlighting. [[Screenshot](https://raw.githubusercontent.com/gist/potatomashed/5a9b20edbdde1b9a91a360baa6bce9ff/raw/3c68031eaba0620a93add270f8ad7ed2c8724a78/mlc-python-printer.svg)]

```python
>>> mlcp.print_python(f) # Syntax highlighting
```

**AST Parser.** MLC has a concise set of APIs for implementing parser with Python's AST module, including:
- Inspection API that obtains source code of a Python class or function and the variables they capture;
- Variable management APIs that help with proper scoping;
- AST fragment evaluation APIs;
- Error rendering APIs.

**[[Example](https://github.com/mlc-ai/mlc-python/blob/main/python/mlc/testing/toy_ir/parser.py)]**. With MLC APIs, a parser can be implemented with 100 lines of code for the Python text format above defined by `__ir_printer__`.

### :zap: Zero-Copy Interoperability with C++ Plugins

TBD
🚧 Under construction.

## :fuelpump: Development

Expand Down
38 changes: 19 additions & 19 deletions include/mlc/core/typing.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ struct AtomicType : public Type {
};

struct PtrTypeObj : protected MLCTypingPtr {
explicit PtrTypeObj(Type ty) : MLCTypingPtr{} { this->TyMutable() = ty; }
Type Ty() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->MLCTypingPtr::ty)); }
explicit PtrTypeObj(Type ty) : MLCTypingPtr{} { this->TyMut() = ty; }
::mlc::Str __str__() const {
std::ostringstream os;
os << "Ptr[" << this->Ty() << "]";
Expand All @@ -138,20 +137,20 @@ struct PtrTypeObj : protected MLCTypingPtr {
MLC_DEF_STATIC_TYPE(PtrTypeObj, TypeObj, MLCTypeIndex::kMLCTypingPtr, "mlc.core.typing.PtrType");

private:
Type &TyMutable() { return reinterpret_cast<Type &>(this->MLCTypingPtr::ty); }
Type &TyMut() { return reinterpret_cast<Type &>(this->MLCTypingPtr::ty); }
Type Ty() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->MLCTypingPtr::ty)); }
};

struct PtrType : public Type {
MLC_DEF_OBJ_REF(PtrType, PtrTypeObj, Type)
.StaticFn("__init__", InitOf<PtrTypeObj, Type>)
.MemFn("_ty", &PtrTypeObj::Ty)
._Field("ty", offsetof(MLCTypingPtr, ty), sizeof(MLCTypingPtr::ty), false, ParseType<Type>())
.MemFn("__str__", &PtrTypeObj::__str__)
.MemFn("__cxx_str__", &PtrTypeObj::__cxx_str__);
};

struct OptionalObj : protected MLCTypingOptional {
explicit OptionalObj(Type ty) : MLCTypingOptional{} { this->TyMutable() = ty; }
Type Ty() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->MLCTypingOptional::ty)); }
::mlc::Str __str__() const {
std::ostringstream os;
os << this->Ty() << " | None";
Expand All @@ -167,19 +166,19 @@ struct OptionalObj : protected MLCTypingOptional {

private:
Type &TyMutable() { return reinterpret_cast<Type &>(this->MLCTypingOptional::ty); }
Type Ty() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->MLCTypingOptional::ty)); }
};

struct Optional : public Type {
MLC_DEF_OBJ_REF(Optional, OptionalObj, Type)
.StaticFn("__init__", InitOf<OptionalObj, Type>)
.MemFn("_ty", &OptionalObj::Ty)
._Field("ty", offsetof(MLCTypingOptional, ty), sizeof(MLCTypingOptional::ty), false, ParseType<Type>())
.MemFn("__str__", &OptionalObj::__str__)
.MemFn("__cxx_str__", &OptionalObj::__cxx_str__);
};

struct ListObj : protected MLCTypingList {
explicit ListObj(Type ty) : MLCTypingList{} { this->TyMutable() = ty; }
Type Ty() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->MLCTypingList::ty)); }
::mlc::Str __str__() const {
std::ostringstream os;
os << "list[" << this->Ty() << "]";
Expand All @@ -195,47 +194,48 @@ struct ListObj : protected MLCTypingList {

protected:
Type &TyMutable() { return reinterpret_cast<Type &>(this->MLCTypingList::ty); }
Type Ty() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->MLCTypingList::ty)); }
};

struct List : public Type {
MLC_DEF_OBJ_REF(List, ListObj, Type)
.StaticFn("__init__", InitOf<ListObj, Type>)
.MemFn("_ty", &ListObj::Ty)
._Field("ty", offsetof(MLCTypingList, ty), sizeof(MLCTypingList::ty), false, ParseType<Type>())
.MemFn("__str__", &ListObj::__str__)
.MemFn("__cxx_str__", &ListObj::__cxx_str__);
};

struct DictObj : protected MLCTypingDict {
explicit DictObj(Type ty_k, Type ty_v) : MLCTypingDict{} {
this->TyMutableK() = ty_k;
this->TyMutableV() = ty_v;
this->TyKMut() = ty_k;
this->TyVMut() = ty_v;
}
Type key() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->ty_k)); }
Type value() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->ty_v)); }
::mlc::Str __str__() const {
std::ostringstream os;
os << "dict[" << this->key() << ", " << this->value() << "]";
os << "dict[" << this->TyK() << ", " << this->TyV() << "]";
return os.str();
}
::mlc::Str __cxx_str__() const {
::mlc::Str k_str = ::mlc::base::LibState::CxxStr(this->key());
::mlc::Str v_str = ::mlc::base::LibState::CxxStr(this->value());
::mlc::Str k_str = ::mlc::base::LibState::CxxStr(this->TyK());
::mlc::Str v_str = ::mlc::base::LibState::CxxStr(this->TyV());
std::ostringstream os;
os << "::mlc::Dict<" << k_str->data() << ", " << v_str->data() << ">";
return os.str();
}
MLC_DEF_STATIC_TYPE(DictObj, TypeObj, MLCTypeIndex::kMLCTypingDict, "mlc.core.typing.Dict");

protected:
Type &TyMutableK() { return reinterpret_cast<Type &>(this->ty_k); }
Type &TyMutableV() { return reinterpret_cast<Type &>(this->ty_v); }
Type &TyKMut() { return reinterpret_cast<Type &>(this->ty_k); }
Type &TyVMut() { return reinterpret_cast<Type &>(this->ty_v); }
Type TyK() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->ty_k)); }
Type TyV() const { return Type(reinterpret_cast<const Ref<TypeObj> &>(this->ty_v)); }
};

struct Dict : public Type {
MLC_DEF_OBJ_REF(Dict, DictObj, Type)
.StaticFn("__init__", InitOf<DictObj, Type, Type>)
.MemFn("_key", &DictObj::key)
.MemFn("_value", &DictObj::value)
._Field("ty_k", offsetof(MLCTypingDict, ty_k), sizeof(MLCTypingDict::ty_k), false, ParseType<Type>())
._Field("ty_v", offsetof(MLCTypingDict, ty_v), sizeof(MLCTypingDict::ty_v), false, ParseType<Type>())
.MemFn("__str__", &DictObj::__str__)
.MemFn("__cxx_str__", &DictObj::__cxx_str__);
};
Expand Down
7 changes: 7 additions & 0 deletions include/mlc/core/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ struct ReflectionHelper {
return *this;
}

inline ReflectionHelper &_Field(const char *name, int64_t field_offset, int32_t num_bytes, bool frozen, Any ty) {
this->any_pool.push_back(ty);
int32_t index = static_cast<int32_t>(this->fields.size());
this->fields.emplace_back(MLCTypeField{name, index, field_offset, num_bytes, frozen, ty.v.v_obj});
return *this;
}

template <typename Callable> inline ReflectionHelper &MemFn(const char *name, Callable &&method) {
MLCTypeMethod m = this->PrepareMethod(name, std::forward<Callable>(method));
m.kind = kMemFn;
Expand Down
13 changes: 7 additions & 6 deletions include/mlc/printer/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct IRPrinterObj : public Object {

bool VarIsDefined(const ObjectRef &obj) { return obj2info->count(obj) > 0; }

Id VarDef(const ObjectRef &obj, const ObjectRef &frame, Str name_hint) {
Id VarDef(Str name_hint, const ObjectRef &obj, const Optional<ObjectRef> &frame) {
if (auto it = obj2info.find(obj); it != obj2info.end()) {
Optional<Str> name = (*it).second->name;
return Id(name.value());
Expand All @@ -66,18 +66,19 @@ struct IRPrinterObj : public Object {
name = name_hint.ToStdString() + '_' + std::to_string(i);
}
defined_names->Set(name, 1);
this->_VarDef(obj, frame, VarInfo(name, Func([name]() { return Id(name); })));
this->_VarDef(VarInfo(name, Func([name]() { return Id(name); })), obj, frame);
return Id(name);
}

void VarDefNoName(const ObjectRef &obj, const ObjectRef &frame, const Func &creator) {
void VarDefNoName(const Func &creator, const ObjectRef &obj, const Optional<ObjectRef> &frame) {
if (obj2info.count(obj) > 0) {
MLC_THROW(KeyError) << "Variable already defined: " << obj;
}
this->_VarDef(obj, frame, VarInfo(mlc::Null, creator));
this->_VarDef(VarInfo(mlc::Null, creator), obj, frame);
}

void _VarDef(const ObjectRef &obj, const ObjectRef &frame, VarInfo var_info) {
void _VarDef(VarInfo var_info, const ObjectRef &obj, const Optional<ObjectRef> &_frame) {
ObjectRef frame = _frame.defined() ? _frame.value() : this->frames.back().operator ObjectRef();
obj2info->Set(obj, var_info);
auto it = frame_vars.find(frame);
if (it == frame_vars.end()) {
Expand All @@ -99,7 +100,7 @@ struct IRPrinterObj : public Object {
obj2info.erase(it);
}

Optional<Id> VarGet(const ObjectRef &obj) {
Optional<Expr> VarGet(const ObjectRef &obj) {
auto it = obj2info.find(obj);
if (it == obj2info.end()) {
return Null;
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
[project]
name = "mlc-python"
version = "0.0.13"
version = "0.0.14"
dependencies = [
'numpy >= 1.22',
"ml-dtypes >= 0.1",
'ml-dtypes >= 0.1',
'Pygments>=2.4.0',
'colorama',
'setuptools ; platform_system == "Windows"',
]
description = ""
Expand Down
2 changes: 1 addition & 1 deletion python/mlc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import _cython, ast, cc, dataclasses, printer
from . import _cython, cc, dataclasses, parser, printer
from ._cython import Ptr, Str
from .core import DataType, Device, Dict, Error, Func, List, Object, ObjectPath, typing
from .dataclasses import PyClass, c_class, py_class
8 changes: 6 additions & 2 deletions python/mlc/_cython/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1360,12 +1360,16 @@ def make_mlc_init(list fields):
cdef tuple setters = _setters
cdef int32_t num_args = len(args)
cdef int32_t i = 0
cdef object e = None
assert num_args == len(setters)
while i < num_args:
try:
setters[i](self, args[i])
except Exception as e: # no-cython-lint
raise ValueError(f"Failed to set field `{fields[i].name}`: {str(e)}. Got: {args[i]}")
except Exception as _e: # no-cython-lint
e = ValueError(f"Failed to set field `{fields[i].name}`: {str(_e)}. Got: {args[i]}")
e = e.with_traceback(_e.__traceback__)
if e is not None:
raise e
i += 1

return _mlc_init
Expand Down
4 changes: 0 additions & 4 deletions python/mlc/ast/__init__.py

This file was deleted.

Loading

0 comments on commit c1a179c

Please sign in to comment.