diff --git a/hamilton/driver.py b/hamilton/driver.py index 711384024..8a771f8ba 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -1669,6 +1669,17 @@ def with_config(self, config: Dict[str, Any]) -> "Builder": self.config.update(config) return self + def with_local_modules(self) -> "Builder": + """Adds the local modules to the modules list. + + :return: self + """ + import inspect + + module = inspect.getmodule(inspect.stack()[1][0]) + self.modules.append(module) + return self + def with_modules(self, *modules: ModuleType) -> "Builder": """Adds the specified modules to the modules list. This can be called multiple times -- later calls will take precedence. diff --git a/tests/test_driver_local_modules.py b/tests/test_driver_local_modules.py new file mode 100644 index 000000000..8e5e44c76 --- /dev/null +++ b/tests/test_driver_local_modules.py @@ -0,0 +1,7 @@ +from hamilton.driver import Builder, DefaultGraphExecutor + + +def test_driver_with_local_modules() -> None: + dr = Builder().with_local_modules().build() + assert isinstance(dr.graph_executor, DefaultGraphExecutor) + assert __name__ == list(dr.graph_modules)[0].__name__