Skip to content

Commit

Permalink
feat: implement FileLoader, add DictFileLoader, enable multi-file ope…
Browse files Browse the repository at this point in the history
…ration
  • Loading branch information
xen0n committed Jun 18, 2024
1 parent 144b428 commit 010d0a4
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,8 @@ impl PySubGlobalsBuilder {
}

#[pyclass(module = "xingque", name = "FrozenModule")]
pub(crate) struct PyFrozenModule(FrozenModule);
#[derive(Clone)]
pub(crate) struct PyFrozenModule(pub(crate) FrozenModule);

impl From<FrozenModule> for PyFrozenModule {
fn from(value: FrozenModule) -> Self {
Expand Down
98 changes: 93 additions & 5 deletions src/eval.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use std::collections::HashMap;

use anyhow::anyhow;
use pyo3::exceptions::PyRuntimeError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use starlark::environment::Module;
use starlark::eval::Evaluator;
use starlark::environment::{FrozenModule, Module};
use starlark::eval::{Evaluator, FileLoader};

use crate::codemap::PyFileSpan;
use crate::environment::{PyGlobals, PyModule};
use crate::environment::{PyFrozenModule, PyGlobals, PyModule};
use crate::syntax::PyAstModule;
use crate::{py2sl, sl2py};

Expand All @@ -17,6 +19,7 @@ pub(crate) struct PyEvaluator(
Evaluator<'static, 'static>,
// this reference is necessary for memory safety
#[allow(dead_code)] Py<PyModule>,
PyObjectFileLoader,
);

impl PyEvaluator {
Expand All @@ -25,7 +28,11 @@ impl PyEvaluator {
let module = module.borrow();
let module = module.inner()?;
let module: &'static Module = unsafe { ::core::mem::transmute(module) };
Ok(Self(Evaluator::new(module), module_ref))
Ok(Self(
Evaluator::new(module),
module_ref,
PyObjectFileLoader::default(),
))
}

fn ensure_module_available(&self, py: Python) -> PyResult<()> {
Expand Down Expand Up @@ -83,7 +90,17 @@ impl PyEvaluator {
Ok(())
}

// TODO: set_loader
fn set_loader(&mut self, py: Python, loader: &Bound<'_, PyAny>) -> PyResult<()> {
self.ensure_module_available(py)?;
self.2.set(loader.as_unbound().clone());
let ptr: &'_ dyn FileLoader = &self.2;
// Safety: actually the wrapper object and the evaluator are identically
// scoped
let ptr: &'static dyn FileLoader = unsafe { ::core::mem::transmute(ptr) };
self.0.set_loader(ptr);
Ok(())
}

// TODO: enable_profile
// TODO: write_profile
// TODO: gen_profile
Expand Down Expand Up @@ -174,3 +191,74 @@ impl PyEvaluator {
}
}
}

// it would be good if https://github.com/PyO3/pyo3/issues/1190 is implemented
// so we could have stronger typing
// but currently duck-typing isn't bad anyway
// this is why we don't declare this as a pyclass right now
#[derive(Debug, Default)]
pub(crate) struct PyObjectFileLoader(Option<PyObject>);

impl PyObjectFileLoader {
fn set(&mut self, obj: PyObject) {
self.0 = Some(obj);
}
}

impl FileLoader for PyObjectFileLoader {
fn load(&self, path: &str) -> anyhow::Result<FrozenModule> {
if let Some(inner) = self.0.as_ref() {
Python::with_gil(|py| {
// first check if it's a PyDictFileLoader and forward to its impl
if let Ok(x) = inner.downcast_bound::<PyDictFileLoader>(py) {
return x.borrow().load(path);
}

// duck-typing
// call the wrapped PyObject's "load" method with the path
// and expect the return value to be exactly PyFrozenModule
let name = intern!(py, "load");
let args = PyTuple::new_bound(py, &[path]);
Ok(inner
.call_method_bound(py, name, args, None)?
.extract::<PyFrozenModule>(py)?
.0)
})
} else {
// this should never happen because we control the only place where
// this struct could possibly get instantiated, and a PyObject is
// guaranteed there (remember None is also a non-null PyObject)
unreachable!()
}
}
}

// a PyDict is wrapped here instead of the ReturnFileLoader (so we effectively
// don't wrap ReturnFileLoader but provide equivalent functionality that's
// idiomatic in Python), because unfortunately ReturnFileLoader has a lifetime
// parameter, but luckily it's basically just a reference to a HashMap and its
// logic is trivial.
#[pyclass(module = "xingque", name = "DictFileLoader")]
pub(crate) struct PyDictFileLoader(Py<PyDict>);

#[pymethods]
impl PyDictFileLoader {
#[new]
fn py_new(modules: Py<PyDict>) -> Self {
Self(modules)
}
}

impl FileLoader for PyDictFileLoader {
fn load(&self, path: &str) -> anyhow::Result<FrozenModule> {
let result: anyhow::Result<_> =
Python::with_gil(|py| match self.0.bind(py).get_item(path)? {
Some(v) => Ok(Some(v.extract::<PyFrozenModule>()?)),
None => Ok(None),
});
result?.map(|x| x.0).ok_or(anyhow!(
"DictFileLoader does not know the module `{}`",
path
))
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ fn xingque(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<environment::PyGlobalsBuilder>()?;
m.add_class::<environment::PyLibraryExtension>()?;
m.add_class::<environment::PyModule>()?;
m.add_class::<eval::PyDictFileLoader>()?;
m.add_class::<eval::PyEvaluator>()?;
m.add_class::<syntax::PyAstModule>()?;
m.add_class::<syntax::PyDialect>()?;
Expand Down
35 changes: 35 additions & 0 deletions tests/test_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,38 @@ def square(x):
# TODO: wrap the "function" type somehow
# sq = e.module.get('square')
# assert e.eval_function(sq, 12) == 144


def test_load_stmt():
"""The original starlark-rust example "Enable the `load` statement" ported
to Python."""

def get_source(file: str) -> str:
match file:
case "a.star":
return "a = 7"
case "b.star":
return "b = 6"
case _:
return """
load('a.star', 'a')
load('b.star', 'b')
ab = a * b
"""

def get_module(file: str) -> xingque.FrozenModule:
ast = xingque.AstModule.parse(file, get_source(file), xingque.Dialect.STANDARD)
modules = {}
for load in ast.loads:
modules[load.module_id] = get_module(load.module_id)
loader = xingque.DictFileLoader(modules)

globals = xingque.Globals.standard()
module = xingque.Module()
eval = xingque.Evaluator(module)
eval.set_loader(loader)
eval.eval_module(ast, globals)
return module.freeze()

ab = get_module("ab.star")
assert ab.get("ab") == 42
11 changes: 9 additions & 2 deletions xingque.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Iterable, Iterator, Self
from typing import Callable, Iterable, Iterator, Protocol, Self

# starlark::codemap

Expand Down Expand Up @@ -196,14 +196,21 @@ class Module:

# starlark::eval

class _FileLoader(Protocol):
def load(self, path: str) -> FrozenModule: ...

class DictFileLoader:
def __init__(self, modules: dict[str, FrozenModule]) -> None: ...
def load(self, path: str) -> FrozenModule: ...

class Evaluator:
def __init__(self, module: Module | None = None) -> None: ...
# TODO: disable_gc
def eval_statements(self, statements: AstModule) -> object: ...
def local_variables(self) -> dict[str, object]: ...
def verbose_gc(self) -> None: ...
def enable_static_typechecking(self, enable: bool) -> None: ...
# TODO: set_loader
def set_loader(self, loader: _FileLoader) -> None: ...
# TODO: enable_profile
# TODO: write_profile
# TODO: gen_profile
Expand Down

0 comments on commit 010d0a4

Please sign in to comment.