diff --git a/src/environment.rs b/src/environment.rs index 3fc4229..e791ea9 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -434,7 +434,8 @@ impl PySubGlobalsBuilder { } #[pyclass(module = "xingque", name = "FrozenModule")] -pub(crate) struct PyFrozenModule(FrozenModule); +#[derive(Clone)] +pub(crate) struct PyFrozenModule(pub(crate) FrozenModule); impl From for PyFrozenModule { fn from(value: FrozenModule) -> Self { diff --git a/src/eval.rs b/src/eval.rs index d1da174..88c5dfa 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -1,13 +1,15 @@ use std::collections::HashMap; +use anyhow::anyhow; use pyo3::exceptions::PyRuntimeError; +use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyTuple}; -use starlark::environment::Module; -use starlark::eval::Evaluator; +use starlark::environment::{FrozenModule, Module}; +use starlark::eval::{Evaluator, FileLoader}; use crate::codemap::PyFileSpan; -use crate::environment::{PyGlobals, PyModule}; +use crate::environment::{PyFrozenModule, PyGlobals, PyModule}; use crate::syntax::PyAstModule; use crate::{py2sl, sl2py}; @@ -17,6 +19,7 @@ pub(crate) struct PyEvaluator( Evaluator<'static, 'static>, // this reference is necessary for memory safety #[allow(dead_code)] Py, + PyObjectFileLoader, ); impl PyEvaluator { @@ -25,7 +28,11 @@ impl PyEvaluator { let module = module.borrow(); let module = module.inner()?; let module: &'static Module = unsafe { ::core::mem::transmute(module) }; - Ok(Self(Evaluator::new(module), module_ref)) + Ok(Self( + Evaluator::new(module), + module_ref, + PyObjectFileLoader::default(), + )) } fn ensure_module_available(&self, py: Python) -> PyResult<()> { @@ -83,7 +90,17 @@ impl PyEvaluator { Ok(()) } - // TODO: set_loader + fn set_loader(&mut self, py: Python, loader: &Bound<'_, PyAny>) -> PyResult<()> { + self.ensure_module_available(py)?; + self.2.set(loader.as_unbound().clone()); + let ptr: &'_ dyn FileLoader = &self.2; + // Safety: actually the wrapper object and the evaluator are identically + // scoped + let ptr: &'static dyn FileLoader = unsafe { ::core::mem::transmute(ptr) }; + self.0.set_loader(ptr); + Ok(()) + } + // TODO: enable_profile // TODO: write_profile // TODO: gen_profile @@ -174,3 +191,74 @@ impl PyEvaluator { } } } + +// it would be good if https://github.com/PyO3/pyo3/issues/1190 is implemented +// so we could have stronger typing +// but currently duck-typing isn't bad anyway +// this is why we don't declare this as a pyclass right now +#[derive(Debug, Default)] +pub(crate) struct PyObjectFileLoader(Option); + +impl PyObjectFileLoader { + fn set(&mut self, obj: PyObject) { + self.0 = Some(obj); + } +} + +impl FileLoader for PyObjectFileLoader { + fn load(&self, path: &str) -> anyhow::Result { + if let Some(inner) = self.0.as_ref() { + Python::with_gil(|py| { + // first check if it's a PyDictFileLoader and forward to its impl + if let Ok(x) = inner.downcast_bound::(py) { + return x.borrow().load(path); + } + + // duck-typing + // call the wrapped PyObject's "load" method with the path + // and expect the return value to be exactly PyFrozenModule + let name = intern!(py, "load"); + let args = PyTuple::new_bound(py, &[path]); + Ok(inner + .call_method_bound(py, name, args, None)? + .extract::(py)? + .0) + }) + } else { + // this should never happen because we control the only place where + // this struct could possibly get instantiated, and a PyObject is + // guaranteed there (remember None is also a non-null PyObject) + unreachable!() + } + } +} + +// a PyDict is wrapped here instead of the ReturnFileLoader (so we effectively +// don't wrap ReturnFileLoader but provide equivalent functionality that's +// idiomatic in Python), because unfortunately ReturnFileLoader has a lifetime +// parameter, but luckily it's basically just a reference to a HashMap and its +// logic is trivial. +#[pyclass(module = "xingque", name = "DictFileLoader")] +pub(crate) struct PyDictFileLoader(Py); + +#[pymethods] +impl PyDictFileLoader { + #[new] + fn py_new(modules: Py) -> Self { + Self(modules) + } +} + +impl FileLoader for PyDictFileLoader { + fn load(&self, path: &str) -> anyhow::Result { + let result: anyhow::Result<_> = + Python::with_gil(|py| match self.0.bind(py).get_item(path)? { + Some(v) => Ok(Some(v.extract::()?)), + None => Ok(None), + }); + result?.map(|x| x.0).ok_or(anyhow!( + "DictFileLoader does not know the module `{}`", + path + )) + } +} diff --git a/src/lib.rs b/src/lib.rs index 002f849..77b6af2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,6 +25,7 @@ fn xingque(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/tests/test_smoke.py b/tests/test_smoke.py index c53130a..6042504 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -23,3 +23,38 @@ def square(x): # TODO: wrap the "function" type somehow # sq = e.module.get('square') # assert e.eval_function(sq, 12) == 144 + + +def test_load_stmt(): + """The original starlark-rust example "Enable the `load` statement" ported + to Python.""" + + def get_source(file: str) -> str: + match file: + case "a.star": + return "a = 7" + case "b.star": + return "b = 6" + case _: + return """ +load('a.star', 'a') +load('b.star', 'b') +ab = a * b +""" + + def get_module(file: str) -> xingque.FrozenModule: + ast = xingque.AstModule.parse(file, get_source(file), xingque.Dialect.STANDARD) + modules = {} + for load in ast.loads: + modules[load.module_id] = get_module(load.module_id) + loader = xingque.DictFileLoader(modules) + + globals = xingque.Globals.standard() + module = xingque.Module() + eval = xingque.Evaluator(module) + eval.set_loader(loader) + eval.eval_module(ast, globals) + return module.freeze() + + ab = get_module("ab.star") + assert ab.get("ab") == 42 diff --git a/xingque.pyi b/xingque.pyi index 3ffb04b..9cda1c0 100644 --- a/xingque.pyi +++ b/xingque.pyi @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Iterator, Self +from typing import Callable, Iterable, Iterator, Protocol, Self # starlark::codemap @@ -196,6 +196,13 @@ class Module: # starlark::eval +class _FileLoader(Protocol): + def load(self, path: str) -> FrozenModule: ... + +class DictFileLoader: + def __init__(self, modules: dict[str, FrozenModule]) -> None: ... + def load(self, path: str) -> FrozenModule: ... + class Evaluator: def __init__(self, module: Module | None = None) -> None: ... # TODO: disable_gc @@ -203,7 +210,7 @@ class Evaluator: def local_variables(self) -> dict[str, object]: ... def verbose_gc(self) -> None: ... def enable_static_typechecking(self, enable: bool) -> None: ... - # TODO: set_loader + def set_loader(self, loader: _FileLoader) -> None: ... # TODO: enable_profile # TODO: write_profile # TODO: gen_profile