Skip to content

Commit

Permalink
Fix incorrect call return for python device kernels (#2319)
Browse files Browse the repository at this point in the history
* Fix incorrect call result in python

Signed-off-by: Anna Gringauze <[email protected]>

* Temp

Signed-off-by: Anna Gringauze <[email protected]>

* Add more tests

Signed-off-by: Anna Gringauze <[email protected]>

* DCO Remediation Commit for Anna Gringauze <[email protected]>

I, Anna Gringauze <[email protected]>, hereby add my Signed-off-by to this commit: ce0fd32

Signed-off-by: Anna Gringauze <[email protected]>

* Cleanup

Signed-off-by: Anna Gringauze <[email protected]>

* Address CR comments

Signed-off-by: Anna Gringauze <[email protected]>

---------

Signed-off-by: Anna Gringauze <[email protected]>
  • Loading branch information
annagrin authored Nov 3, 2024
1 parent 6a47ffc commit 2c162ab
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 2 deletions.
24 changes: 22 additions & 2 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs):
self.buildingEntryPoint = False
self.inForBodyStack = deque()
self.inIfStmtBlockStack = deque()
self.currentAssignVariableName = None
self.walkingReturnNode = False
self.controlNegations = []
self.subscriptPushPointerValue = False
self.verbose = 'verbose' in kwargs and kwargs['verbose']
Expand Down Expand Up @@ -900,6 +902,10 @@ def visit_FunctionDef(self, node):
self.mlirTypeFromAnnotation(arg.annotation)
for arg in node.args.args
]
parentResultType = self.knownResultType
if node.returns != None:
self.knownResultType = self.mlirTypeFromAnnotation(node.returns)

# Get the argument names
argNames = [arg.arg for arg in node.args.args]

Expand Down Expand Up @@ -981,6 +987,8 @@ def isQuantumTy(ty):
self.symbolTable.clear()
self.valueStack.clear()

self.knownResultType = parentResultType

def visit_Expr(self, node):
"""
Implement `ast.Expr` visitation to screen out all
Expand Down Expand Up @@ -1713,7 +1721,7 @@ def bodyBuilder(iterVal):
# If `registerName` is None, then we know that we
# are not assigning this measure result to anything
# so we therefore should not push it on the stack
pushResultToStack = registerName != None
pushResultToStack = registerName != None or self.walkingReturnNode

# By default we set the `register_name` for the measurement
# to the assigned variable name (if there is one). But
Expand Down Expand Up @@ -1870,7 +1878,11 @@ def bodyBuilder(iterVal):
values = [self.popValue() for _ in node.args]
values.reverse()
values = [self.ifPointerThenLoad(v) for v in values]
func.CallOp(otherKernel, values)
if len(fType.results) == 0:
func.CallOp(otherKernel, values)
else:
result = func.CallOp(otherKernel, values).result
self.pushValue(result)
return

elif node.func.id in self.symbolTable:
Expand Down Expand Up @@ -2162,6 +2174,11 @@ def bodyBuilder(iterVal):
node)

if node.func.attr == 'qvector':
if len(self.valueStack) == 0:
self.emitFatalError(
'qvector does not have default constructor. Init from size or existing state.',
node)

valueOrPtr = self.popValue()
initializerTy = valueOrPtr.type

Expand Down Expand Up @@ -3519,7 +3536,9 @@ def visit_Return(self, node):
if node.value == None:
return

self.walkingReturnNode = True
self.visit(node.value)
self.walkingReturnNode = False

if len(self.valueStack) == 0:
return
Expand All @@ -3537,6 +3556,7 @@ def visit_Return(self, node):
byteWidth = 16 if ComplexType.isinstance(eleTy) else 8
eleSize = self.getConstantInt(byteWidth)
dynSize = cc.StdvecSizeOp(self.getIntegerType(), result).result
resBuf = cc.CastOp(ptrTy, resBuf)
heapCopy = func.CallOp([ptrTy], symName,
[resBuf, dynSize, eleSize]).result
res = cc.StdvecInitOp(result.type, heapCopy, dynSize).result
Expand Down
94 changes: 94 additions & 0 deletions python/tests/kernel/test_kernel_call_return.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# ============================================================================ #
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

import cudaq


def test_call_with_callee_return_bool():

@cudaq.kernel
def bar(qubits: cudaq.qview) -> bool:
x(qubits)
return False

@cudaq.kernel
def foo(n: int):
qubits = cudaq.qvector(n)
bar(qubits)

counts = cudaq.sample(foo, 3)
assert "111" in counts and len(counts) == 1


def test_call_with_return_bool():

@cudaq.kernel()
def callee(q: cudaq.qubit) -> bool:
x(q)
m = mz(q)
return m

@cudaq.kernel()
def caller() -> bool:
q = cudaq.qubit()
return callee(q)

result = caller()
assert result == True or result == False

counts = cudaq.sample(caller)
assert '1' in counts and len(counts) == 1


def test_call_with_return_bool2():
from dataclasses import dataclass

@dataclass
class patch:
data: cudaq.qview
ancx: cudaq.qview
ancz: cudaq.qview

@cudaq.kernel()
def stabilizer(logicalQubit: patch, x_stabilizers: list[int],
z_stabilizers: list[int]) -> bool:
for xi in range(len(logicalQubit.ancx)):
for di in range(len(logicalQubit.data)):
if x_stabilizers[xi * len(logicalQubit.data) + di] == 1:
x.ctrl(logicalQubit.ancx[xi], logicalQubit.data[di])

h(logicalQubit.ancx)
for zi in range(len(logicalQubit.ancz)):
for di in range(len(logicalQubit.data)):
if z_stabilizers[zi * len(logicalQubit.data) + di] == 1:
x.ctrl(logicalQubit.data[di], logicalQubit.ancz[zi])

results = mz(logicalQubit.ancx, logicalQubit.ancz)

reset(logicalQubit.ancx)
reset(logicalQubit.ancz)
#TODO: support returning lists
#Issue: https://github.com/NVIDIA/cuda-quantum/issues/2336
return results[3]

@cudaq.kernel()
def run() -> bool:
q = cudaq.qvector(2)
x(q[0])
r = cudaq.qvector(2)
s = cudaq.qvector(2)
p = patch(q, r, s)

return stabilizer(p, [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1])

result = run()
assert result == True or result == False

sample_result = cudaq.sample(run)
counts = sample_result.get_register_counts("results")
assert len(counts) == 4

0 comments on commit 2c162ab

Please sign in to comment.