Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow calling a tensor of functions #196

Merged
merged 46 commits into from
May 14, 2024
Merged

Conversation

croyzor
Copy link
Collaborator

@croyzor croyzor commented Apr 16, 2024

Allow calling functions as if using tensor product of operators, with tuple syntax: (f, g)(a, b, c). This should work whenever all of the tuple elements are themselves callable, and otherwise raise an error.

  • Calls of tuples compile down to a TensorCall which packs up the LocalCalls and GlobalCalls to the functions in the tensor product
  • Some parts of typechecking have been generalised to allow more arguments than the function can process and return the leftovers. I did this instead of being concrete about arg numbers in order to facilitate use of the * operator for chaining calls like (f, g)(*h(a, b)) but there might be an easier way.

Resolves #141

Update: This no longer adds any new nodes or types -- it just uses tuples and adds a couple of helper functions to deal with them

Base automatically changed from feat/defs to main April 16, 2024 10:38
@croyzor croyzor force-pushed the feat/tensor-call-only branch from 0ec5385 to 0943e68 Compare April 16, 2024 11:02
@croyzor croyzor requested a review from mark-koch April 16, 2024 11:15
@croyzor croyzor marked this pull request as ready for review April 16, 2024 11:15
@mark-koch
Copy link
Collaborator

I'll do a more detailed review later, but here are some first thoughts:

I'm not sure if constructing a FunctionTensor during CFG construction is the best move since this will only work if you call the tuple directly. For example, I belive the following doesn't work, right?

@guppy
def foo() -> int:
    return 42

@guppy
def bar() -> bool:
    return True

@guppy
def baz() -> tuple[int, bool]:
    f = foo, bar
    return f()

In order to call "dynamic" function tuples, you'd have to wait until type checking a call f(args), infer the type of f, and check if the type is callable, i.e. if

  • it is function type, or
  • it is a type that implements __call__, or
  • it is tuple of callable types

If you made it such that callable types check against function types, then you wouldn't even need the tensor type since you could use Callable everywhere:

@guppy
def tensor_foo_bar() -> Callable[[], tuple[int, bool]]:
    return foo, bar

Copy link
Collaborator

@mark-koch mark-koch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is looking quite good overall 👍 My main concern is that the call checking code is becoming a bit hard to read, covering so many cases. Also, I'm not sure if nested tuples are always handled correctly?

One suggestion to consider is that we could get rid of the special logic that is used to check tuple literals and instead always rely on the general version using function_tensor_signature() and check_call() / synthesize_call()? This would simplify the logic substantially and give us nesting for free

guppylang/cfg/builder.py Outdated Show resolved Hide resolved
guppylang/checker/expr_checker.py Outdated Show resolved Hide resolved
guppylang/checker/expr_checker.py Outdated Show resolved Hide resolved
else:
elem_tys: list[FunctionType] = []
for i, (elt, elt_ty) in enumerate(zip(node.elts, ty.element_types)):
node.elts[i], fun_ty = self._synthesize(elt, allow_free_vars=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if we want allow_free_vars=True.

Conceptionally, it would be nice since we could use the type we're checking against to infer some unsolved type variables in the tuple. Unfortunately, we have no easy way to report what we have learned back to the tuple element. I have some ideas for rewriting the type inference logic to allow stuff like this but we can't do it at the moment :(

The only thing that would work right now is first synthesising the element to figure out the number of inputs/outputs it expects and then re-checking the element against the expected function type acting only on these inputs/outputs. But I think doing this is probably not worth it...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thinking was that the vars would be resolved when we try to unify with the expected type below?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we can resolve them but we would need to go back to the place where variables were created and fill in the solution. There is currently no good way to do this

guppylang/checker/expr_checker.py Outdated Show resolved Hide resolved
guppylang/tys/ty.py Outdated Show resolved Hide resolved
guppylang/hugr/hugr.py Outdated Show resolved Hide resolved
tests/integration/test_tensor.py Show resolved Hide resolved
tests/integration/test_tensor.py Show resolved Hide resolved
guppylang/checker/expr_checker.py Outdated Show resolved Hide resolved
@croyzor croyzor requested a review from mark-koch May 9, 2024 11:18
@codecov-commenter
Copy link

codecov-commenter commented May 9, 2024

Codecov Report

Attention: Patch coverage is 94.20290% with 4 lines in your changes are missing coverage. Please review.

Project coverage is 90.70%. Comparing base (4f24a07) to head (0373dc6).

Files Patch % Lines
guppylang/checker/expr_checker.py 88.23% 2 Missing ⚠️
guppylang/compiler/expr_compiler.py 96.29% 1 Missing ⚠️
guppylang/tys/ty.py 94.73% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #196      +/-   ##
==========================================
+ Coverage   90.64%   90.70%   +0.06%     
==========================================
  Files          46       46              
  Lines        4724     4789      +65     
==========================================
+ Hits         4282     4344      +62     
- Misses        442      445       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment on lines +269 to +271
if isinstance(func_ty, TupleType) and (
function_elements := parse_function_tensor(func_ty)
):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing to note: An empty function_elements lists also evaluate to False which means we can't call empty tuples. But I think that's fine, the expression ()() isn't that useful :D

guppylang/checker/expr_checker.py Outdated Show resolved Hide resolved
else:
elem_tys: list[FunctionType] = []
for i, (elt, elt_ty) in enumerate(zip(node.elts, ty.element_types)):
node.elts[i], fun_ty = self._synthesize(elt, allow_free_vars=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we can resolve them but we would need to go back to the place where variables were created and fill in the solution. There is currently no good way to do this

guppylang/checker/expr_checker.py Outdated Show resolved Hide resolved
guppylang/checker/expr_checker.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@mark-koch mark-koch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, this looks a lot cleaner!

guppylang/compiler/expr_compiler.py Outdated Show resolved Hide resolved
guppylang/checker/expr_checker.py Outdated Show resolved Hide resolved
guppylang/nodes.py Outdated Show resolved Hide resolved
guppylang/checker/expr_checker.py Outdated Show resolved Hide resolved
guppylang/tys/ty.py Show resolved Hide resolved
guppylang/checker/expr_checker.py Outdated Show resolved Hide resolved
guppylang/checker/expr_checker.py Outdated Show resolved Hide resolved
tests/integration/test_tensor.py Show resolved Hide resolved
@croyzor croyzor requested a review from mark-koch May 14, 2024 13:23
Copy link
Collaborator

@mark-koch mark-koch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some final nits. The tests are looking good 👍

Feel free to merge afterwards

guppylang/checker/expr_checker.py Outdated Show resolved Hide resolved
Comment on lines 256 to 258
processed_args, subst, inst = check_call(
tensor_ty, node.args, tensor_ty.output, node.func, self.ctx
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if you do this instead you can get rid of the unification below

Suggested change
processed_args, subst, inst = check_call(
tensor_ty, node.args, tensor_ty.output, node.func, self.ctx
)
processed_args, subst, inst = check_call(
tensor_ty, node.args, ty, node, self.ctx
)
assert len(inst) == 0
return return with_loc(node, TensorCall(...)), subst

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also notice the updated location node.func to node. I think this should improve the error location in some of the golden tests

Comment on lines 206 to 209
if len(rets) == 1:
return rets[0]
else:
return self._pack_returns(rets, node.out_tys)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this len == 1 special case needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be redundant, duplicating the logic of pack_returns. I'll remove it

@croyzor croyzor merged commit af4fb07 into main May 14, 2024
3 checks passed
@croyzor croyzor deleted the feat/tensor-call-only branch May 14, 2024 15:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Syntax for quantum function tensor product
3 participants