Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a Program from an OpenQASM3 string #59

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
348 changes: 348 additions & 0 deletions oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing import Any, Iterable, Iterator, Optional

from openpulse import ast
from openpulse.parser import parse
from openpulse.printer import dumps
from openqasm3.visitor import QASMVisitor

Expand Down Expand Up @@ -307,6 +308,14 @@ def to_ast(
MergeCalStatementsPass().visit(prog)
return prog

@staticmethod
def from_qasm(source: str) -> Program:
"""Build an OQPy program by parsing OpenQASM text."""
prog = Program()
oqasm_ast = parse(source)
ProgramBuilder().visit(oqasm_ast, prog)
return prog

def to_qasm(
self,
encal: bool = False,
Expand Down Expand Up @@ -602,3 +611,342 @@ def process_statement_list(
new_list.append(ast.CalibrationStatement(body=cal_stmts))

return new_list


class ProgramBuilder(QASMVisitor[Program]):
"""AST Transformer class that modifies the tree created from parsing openqasm input text.

It separates:
- extern declarations and stores them in Program().externs.
- subroutines and stores them in Program().subroutines
- defcals and stores in Program().defcals
It also creates the corresponding OQpy variables everytime it encounters a classical
or pulse type.
"""

inside_def_block: bool = False

TIME_UNIT_TO_EXP = {"ns": 3, "us": 2, "ms": 1, "s": 0}

def generic_visit(self, node: ast.QASMNode, context: Program | None = None) -> dict[str, Any]:
res_value: dict[str, Any] = {}
for field, old_value in node.__dict__.items():
if isinstance(old_value, list):
new_values = []
res_value[field] = []
for value in old_value:
if isinstance(value, ast.QASMNode):
res = self.visit(value, context) if context else self.visit(value)
value = res["node"]
if "value" in res:
res_value[field].append(res["value"])
if value is None:
continue
elif not isinstance(value, ast.QASMNode):
new_values.extend(value)
continue
elif isinstance(value, list):
my_table = []
for idx, element in enumerate(value):
res = self.visit(element, context) if context else self.visit(element)
value[idx] = res["node"]
if "value" in res:
my_table.append(res["value"])
new_values.append(value)
res_value[field].append(my_table)
continue
else:
raise TypeError(f"Got {type(value)} for {field}")
new_values.append(value)
old_value[:] = new_values
elif isinstance(old_value, ast.QASMNode):
res = self.visit(old_value, context) if context else self.visit(old_value)
if isinstance(res, ast.QASMNode):
new_node = res
else:
new_node = res["node"]
res_value[field] = res["value"]
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)
return {"node": node, "value": res_value if res_value is not {} else None}

def visit(self, node: ast.QASMNode, context: Optional[Program] = None) -> dict[str, Any]:
"""Visit a node."""
var: Var | None = None
if hasattr(node, "span"):
node.span = None
if not isinstance(node, ast.ClassicalDeclaration) and hasattr(node, "type"):
if hasattr(node, "identifier"):
var = self.create_oqpy_var(node.type, node.identifier.name, needs_declaration=False)
elif hasattr(node, "name"):
var = self.create_oqpy_var(node.type, node.name.name, needs_declaration=False)
if context is not None and var is not None:
context._add_var(var)

method = "visit_" + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
# The visitor method may not have the context argument.
if context:
res = visitor(node, context)
else:
res = visitor(node)
if isinstance(res, ast.QASMNode):
return {"node": res}
else:
return res

def visit_Program(self, node: ast.Program, context: Program) -> ast.QASMNode:
node = self.generic_visit(node, context)["node"]

context.version = node.version

for statement in node.statements:
context._add_statement(statement)
return node

def visit_CalibrationGrammarDeclaration(
self, node: ast.CalibrationGrammarDeclaration, context: Program
) -> dict[str, Any]:
return {"node": None}

def visit_ExternDeclaration(
self, node: ast.ExternDeclaration, context: Program
) -> dict[str, Any]:
node = self.generic_visit(node, context)["node"] # Clear spans first
context.externs[node.name.name] = node
return {"node": None}

def visit_ClassicalDeclaration(
self, node: ast.ClassicalDeclaration, context: Program
) -> dict[str, Any]:
var: Var | None

res = self.generic_visit(node, context)
node = res["node"]

if "init_expression" in res["value"]:
init_expression_value = res["value"]["init_expression"]
else:
init_expression_value = None
var = self.create_oqpy_var(node.type, node.identifier.name, init_expression_value)
if var is not None:
context._mark_var_declared(var)
return {"node": node}

def visit_CalibrationDefinition(
self, node: ast.CalibrationDefinition, context: Program
) -> dict[str, Any]:
self.inside_def_block = True
context._add_defcal(
[ident.name for ident in node.qubits],
node.name.name,
[dumps(a) for a in node.arguments],
node,
)
visited_node = self.generic_visit(node, context)["node"]
self.inside_def_block = False
return {"node": visited_node}

def visit_SubroutineDefinition(
self, node: ast.SubroutineDefinition, context: Program
) -> dict[str, Any]:
self.inside_def_block = True
visited_node = self.generic_visit(node, context)["node"]
self.inside_def_block = False
context._add_subroutine(visited_node.name.name, visited_node)
return {"node": None}

def create_oqpy_var(
self,
node_type: ast.ClassicalType,
name: str,
init_expression: Any | None = None,
needs_declaration: bool = True,
) -> Var | None:
if self.inside_def_block:
return None

var: Var | None = None
if isinstance(node_type, ast.BitType):
var = classical_types.BitVar(
init_expression=init_expression, name=name, needs_declaration=needs_declaration
)
elif isinstance(node_type, ast.BoolType):
var = classical_types.BoolVar(
init_expression=init_expression, name=name, needs_declaration=needs_declaration
)
elif isinstance(node_type, ast.IntType):
var = classical_types.IntVar(
init_expression=init_expression, name=name, needs_declaration=needs_declaration
)
elif isinstance(node_type, ast.UintType):
var = classical_types.UintVar(
init_expression=init_expression, name=name, needs_declaration=needs_declaration
)
elif isinstance(node_type, ast.FloatType):
var = classical_types.FloatVar(
init_expression=init_expression, name=name, needs_declaration=needs_declaration
)
elif isinstance(node_type, ast.AngleType):
var = classical_types.AngleVar(
init_expression=init_expression, name=name, needs_declaration=needs_declaration
)
elif isinstance(node_type, ast.ComplexType):
var = classical_types.ComplexVar(
init_expression=init_expression,
name=name,
base_type=node_type.base_type,
needs_declaration=needs_declaration,
)
elif isinstance(node_type, ast.DurationType):
value = None
if isinstance(init_expression, ast.DurationLiteral):
if init_expression.unit.name not in self.TIME_UNIT_TO_EXP:
raise ValueError(
f"Unexpected duration specified: {init_expression.unit.name}:{init_expression.unit.value}"
)
multiplier = 10 ** (-3 * self.TIME_UNIT_TO_EXP[init_expression.unit.name])
value = multiplier * init_expression.value
var = classical_types.DurationVar(
init_expression=value, name=name, needs_declaration=needs_declaration
)
elif isinstance(node_type, ast.StretchType):
var = classical_types.StretchVar(
init_expression=init_expression, name=name, needs_declaration=needs_declaration
)
elif isinstance(node_type, ast.FrameType):
if isinstance(init_expression, dict):
var = FrameVar(name=name, **init_expression)
else:
var = FrameVar(name=name)
elif isinstance(node_type, ast.PortType):
var = PortVar(name=name)
elif isinstance(node_type, ast.WaveformType):
var = WaveformVar(init_expression=init_expression, name=name)
else:
raise TypeError(f"Unsupported type {type(node_type)} was used in the OpenQASM program.")
return var

def visit_FunctionCall(self, node: ast.FunctionCall, context: Program) -> dict[str, Any]:
node = self.generic_visit(node, context)["node"]
if node.name.name == "newframe":
value = {
"port": node.arguments[0].name,
"frequency": node.arguments[1].value,
"phase": node.arguments[2].value,
}
else:
value = None
return {"node": node, "value": value}

def visit_BitstringLiteral(
self, node: ast.BitstringLiteral, context: Program
) -> dict[str, Any]:
value = bin(node.value)[2:]
if len(value) < node.width:
value = "0" * (node.width - len(value)) + value
return {"node": node, "value": value}

def visit_IntegerLiteral(self, node: ast.IntegerLiteral, context: Program) -> dict[str, Any]:
return {"node": node, "value": node.value}

def visit_FloatLiteral(self, node: ast.FloatLiteral, context: Program) -> dict[str, Any]:
return {"node": node, "value": node.value}

def visit_ImaginaryLiteral(
self, node: ast.ImaginaryLiteral, context: Program
) -> dict[str, Any]:
return {"node": node, "value": node.value * 1j}

def visit_BooleanLiteral(self, node: ast.BooleanLiteral, context: Program) -> dict[str, Any]:
return {"node": node, "value": True if node.value else False}

def visit_DurationLiteral(self, node: ast.DurationLiteral, context: Program) -> dict[str, Any]:
return {"node": node, "value": convert_float_to_duration(node.value * 1e-9)}

def visit_ArrayLiteral(self, node: ast.ArrayLiteral, context: Program) -> dict[str, Any]:
return {
"node": node,
"value": [self.generic_visit(n, context)["value"] for n in node.values],
}

def visit_Identifier(self, node: ast.Identifier, context: Program) -> dict[str, Any]:
if node.name in context.declared_vars:
value = context.declared_vars[node.name]
elif node.name in context.undeclared_vars:
value = context.undeclared_vars[node.name]
else:
value = node.name
return {"node": node, "value": value}

def visit_BinaryExpression(
self, node: ast.BinaryExpression, context: Program
) -> dict[str, Any]:
res = self.generic_visit(node, context)
node = res["node"]
lhs = res["value"]["lhs"]
rhs = res["value"]["rhs"]

# FIXME: pass the right type to ast_type
if isinstance(lhs, str):
lhs = classical_types.Identifier(lhs, ast.ClassicalType)
if isinstance(rhs, str):
rhs = classical_types.Identifier(rhs, ast.ClassicalType)

op = ast.BinaryOperator

result = None
if node.op == op["+"]:
result = lhs + rhs
elif node.op == op["-"]:
result = lhs - rhs
elif node.op == op["*"]:
result = lhs * rhs
elif node.op == op["/"]:
result = lhs / rhs
elif node.op == op["%"]:
result = lhs % rhs
elif node.op == op["**"]:
result = lhs**rhs
elif node.op == op[">"]:
result = lhs > rhs
elif node.op == op["<"]:
result = lhs < rhs
elif node.op == op[">="]:
result = lhs >= rhs
elif node.op == op["<="]:
result = lhs <= rhs
elif node.op == op["=="]:
result = lhs == rhs
elif node.op == op["!="]:
result = lhs != rhs
elif node.op == op["&&"]:
result = lhs and rhs
elif node.op == op["||"]:
result = lhs or rhs
elif node.op == op["|"]:
result = lhs | rhs
elif node.op == op["^"]:
result = lhs ^ rhs
elif node.op == op["&"]:
result = lhs & rhs
elif node.op == op["<<"]:
result = lhs << rhs
elif node.op == op[">>"]:
result = lhs >> rhs
return {"node": node, "value": result}

def visit_UnaryExpression(self, node: ast.UnaryExpression, context: Program) -> dict[str, Any]:
res = self.generic_visit(node, context)
node = res["node"]
exp = res["value"]["expression"]

if node.op == ast.UnaryOperator["-"]:
result = -1 * exp
elif node.op == ast.UnaryOperator["!"]:
result = not exp
elif node.op == ast.UnaryOperator["~"]:
result = ~exp
return {"node": node, "value": result}
Loading