feat: Allow calling a tensor of functions #196
Conversation
0ec5385 to 0943e68
I'll do a more detailed review later, but here are some first thoughts: I'm not sure if constructing a …

```python
@guppy
def foo() -> int:
    return 42

@guppy
def bar() -> bool:
    return True

@guppy
def baz() -> tuple[int, bool]:
    f = foo, bar
    return f()
```

In order to call "dynamic" function tuples, you'd have to wait until type checking a call.

If you made it such that callable types check against function types, then you wouldn't even need the …

```python
@guppy
def tensor_foo_bar() -> Callable[[], tuple[int, bool]]:
    return foo, bar
```
All of the logic for tuples of functions is handled when these things are *called*, so nothing needs to be done here.
I think this is looking quite good overall 👍 My main concern is that the call checking code is becoming a bit hard to read, covering so many cases. Also, I'm not sure if nested tuples are always handled correctly?

One suggestion to consider is that we could get rid of the special logic that is used to check tuple literals and instead always rely on the general version using `function_tensor_signature()` and `check_call()` / `synthesize_call()`? This would simplify the logic substantially and give us nesting for free.
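For intuition, here is a rough sketch of what combining element signatures into a single tensor signature could look like. The `FunctionType` below is a simplified stand-in rather than guppylang's actual type, and `function_tensor_signature_sketch` is a hypothetical illustration of the idea, not the real helper:

```python
from dataclasses import dataclass

@dataclass
class FunctionType:
    # Simplified stand-in for guppylang's FunctionType: flat lists of
    # input and output type names.
    inputs: list[str]
    outputs: list[str]

def function_tensor_signature_sketch(elem_tys: list[FunctionType]) -> FunctionType:
    # Concatenate the inputs and outputs of the element functions, so the
    # tuple (f, g) checks like a single function taking all of f's and g's
    # arguments and returning all of their results.
    inputs: list[str] = []
    outputs: list[str] = []
    for ty in elem_tys:
        inputs.extend(ty.inputs)
        outputs.extend(ty.outputs)
    return FunctionType(inputs, outputs)

# f: int -> int and g: bool -> bool combine to (int, bool) -> (int, bool).
f_ty = FunctionType(["int"], ["int"])
g_ty = FunctionType(["bool"], ["bool"])
print(function_tensor_signature_sketch([f_ty, g_ty]))
```

Because the combined signature is built recursively from whatever element signatures are available, a nested tuple of functions would just contribute its own combined signature, which is where the "nesting for free" comes from.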
guppylang/checker/expr_checker.py (Outdated)

```python
else:
    elem_tys: list[FunctionType] = []
    for i, (elt, elt_ty) in enumerate(zip(node.elts, ty.element_types)):
        node.elts[i], fun_ty = self._synthesize(elt, allow_free_vars=True)
```
I'm not sure if we want `allow_free_vars=True`.

Conceptually, it would be nice since we could use the type we're checking against to infer some unsolved type variables in the tuple. Unfortunately, we have no easy way to report what we have learned back to the tuple element. I have some ideas for rewriting the type inference logic to allow stuff like this, but we can't do it at the moment :(

The only thing that would work right now is first synthesising the element to figure out the number of inputs/outputs it expects and then re-checking the element against the expected function type acting only on those inputs/outputs. But I think doing this is probably not worth it...
My thinking was that the vars would be resolved when we try to unify with the expected type below?
Yes we can resolve them but we would need to go back to the place where variables were created and fill in the solution. There is currently no good way to do this
`{}` is falsy...
Codecov Report

Attention: Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #196      +/-   ##
==========================================
+ Coverage   90.64%   90.70%   +0.06%
==========================================
  Files          46       46
  Lines        4724     4789      +65
==========================================
+ Hits         4282     4344      +62
- Misses        442      445       +3
```

☔ View full report in Codecov by Sentry.
```python
if isinstance(func_ty, TupleType) and (
    function_elements := parse_function_tensor(func_ty)
):
```
One thing to note: an empty `function_elements` list also evaluates to `False`, which means we can't call empty tuples. But I think that's fine; the expression `()()` isn't that useful :D
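To illustrate the truthiness point, here is a minimal, self-contained sketch; `parse_function_tensor_stub` is a hypothetical stand-in for the real `parse_function_tensor` helper:

```python
from typing import Optional

def parse_function_tensor_stub(elems: list) -> Optional[list]:
    # Hypothetical stand-in: returns the list of element signatures,
    # which is empty for an empty tuple.
    return list(elems)

# Non-empty tuple: the walrus-bound list is truthy, so the branch is taken.
if function_elements := parse_function_tensor_stub(["f", "g"]):
    print("callable tensor:", function_elements)

# Empty tuple: the walrus binds [], which is falsy, so `()()` falls through
# and is not treated as a tensor call.
if function_elements := parse_function_tensor_stub([]):
    print("never reached")
else:
    print("empty tuple is not treated as callable")
```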
Nice, this looks a lot cleaner!
Just some final nits. The tests are looking good 👍
Feel free to merge afterwards
guppylang/checker/expr_checker.py (Outdated)

```python
processed_args, subst, inst = check_call(
    tensor_ty, node.args, tensor_ty.output, node.func, self.ctx
)
```
I think if you do this instead you can get rid of the unification below:

```diff
-processed_args, subst, inst = check_call(
-    tensor_ty, node.args, tensor_ty.output, node.func, self.ctx
-)
+processed_args, subst, inst = check_call(
+    tensor_ty, node.args, ty, node, self.ctx
+)
+assert len(inst) == 0
+return with_loc(node, TensorCall(...)), subst
```
Also notice the updated location: `node.func` to `node`. I think this should improve the error location in some of the golden tests.
guppylang/compiler/expr_compiler.py (Outdated)

```python
if len(rets) == 1:
    return rets[0]
else:
    return self._pack_returns(rets, node.out_tys)
```
Why is this `len == 1` special case needed?
It should be redundant, duplicating the logic of `_pack_returns`. I'll remove it.
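A minimal sketch of the packing behaviour being discussed, assuming the helper already handles the single-value case itself. This is a hypothetical stand-in, not the actual `_pack_returns` implementation:

```python
def pack_returns_sketch(rets: list, out_tys: list):
    # A single return value is passed through unchanged; multiple values are
    # packed into a tuple. With this behaviour, a separate `len(rets) == 1`
    # branch at the call site duplicates what the helper already does.
    # (out_tys is kept only to mirror the call-site signature above.)
    if len(rets) == 1:
        return rets[0]
    return tuple(rets)

assert pack_returns_sketch([42], ["int"]) == 42
assert pack_returns_sketch([42, True], ["int", "bool"]) == (42, True)
```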
Co-authored-by: Mark Koch <[email protected]>
Allow calling functions as if using the tensor product of operators, with tuple syntax: `(f, g)(a, b, c)`. This should work whenever all of the tuple elements are themselves callable, and otherwise raise an error.

- Adds a `TensorCall` which packs up the `LocalCall`s and `GlobalCall`s to the functions in the tensor product
- Uses the `*` operator for chaining calls like `(f, g)(*h(a, b))`, but there might be an easier way

Resolves #141

Update: This no longer adds any new nodes or types -- it just uses tuples and adds a couple of helper functions to deal with them.
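For reference, a usage sketch of the tuple-call syntax this PR adds, following the foo/bar example quoted earlier in the thread; the exact import path for the `guppy` decorator is an assumption and may differ:

```python
from guppylang import guppy  # assumed import path for the decorator

@guppy
def foo() -> int:
    return 42

@guppy
def bar() -> bool:
    return True

@guppy
def baz() -> tuple[int, bool]:
    # Calling the tuple (foo, bar) acts like the tensor product of the two
    # functions: the (here empty) argument list is split between them and
    # their results are packed into a tuple[int, bool].
    f = foo, bar
    return f()
```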