-
Notifications
You must be signed in to change notification settings - Fork 126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Build with Local Function Directly #793
base: main
Are you sure you want to change the base?
Conversation
@swapdewalkar can you motivate this with an example please? I'm interested, curious to hear where/when you'd use it. |
@skrawcz I saw the requirement here to get all fns that are defined in the local module and insert them into a driver.
instead of
Let me know if I understood something different from ticket. |
|
||
def test_driver_with_local_modules() -> None: | ||
dr = Builder().with_local_modules().build() | ||
assert isinstance(dr.graph_executor, DefaultGraphExecutor) | ||
assert __name__ == list(dr.graph_modules)[0].__name__ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this test probably needs to pull in a function and exercise the driver.
e.g. define
def a(b:int) -> int:
return b * 2
Then we should execute the driver to get "a".
That should work right?
@swapdewalkar cool makes sense -- want to comment on that ticket so I can assign it to you? Otherwise I think we just need a test to exercise this properly. |
@swapdewalkar thanks for the contribution! The Python import system behaves in strange ways at times, so it would be beneficial if you could add to the docstring of For instance, the current approach with A potential solution would be def with_local_modules(self) -> "Builder":
"""Adds the local modules to the modules list.
:return: self
"""
module = __import__("__main__")
self.modules.append(module)
return self This will get the module if __name__ == "__main__":
import __main__ But I invite you to investigate and set the appropriate tests! |
""" | ||
import inspect | ||
|
||
module = inspect.getmodule(inspect.stack()[1][0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See Thierry's comment. I think we can make this more robust.
@@ -1669,6 +1669,17 @@ def with_config(self, config: Dict[str, Any]) -> "Builder": | |||
self.config.update(config) | |||
return self | |||
|
|||
def with_local_modules(self) -> "Builder": | |||
"""Adds the local modules to the modules list. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should add more caveats here.
E.g. If people use this, they could run into problems when using some of the adapters because this code will be imported under the module __main__
.
I will add more tests! |
Just waiting on more tests for this one :) |
--- PR TEMPLATE INSTRUCTIONS (1) ---
Address this Issue #685
Changes
How I tested this
Unit Tests
Notes
Checklist