Skip to content

Commit

Permalink
Add indexing and bitshift expressions (#12310)
Browse files Browse the repository at this point in the history
This adds the concepts of `Index` and the binary operations `SHIFT_LEFT`
and `SHIFT_RIGHT` to the `Expr` system, and threads them through all the
`ExprVisitor` nodes defined in Qiskit, including OQ3 and QPY
serialisation/deserialisation.  (The not-new OQ3 parser is managed
separately and doesn't have support for _any_ `Expr` nodes at the
moment).

Along with `Store`, this should close the gap between what Qiskit was
able to represent with dynamic circuits, and what was supported by
hardware with direct OpenQASM 3 submission, although since this remains
Qiskit's fairly low-level representation, it still was potentially more
ergonomic to use OpenQASM 3 strings.  This remains a general point for
improvement in the Qiskit API, however.
  • Loading branch information
jakelishman authored May 2, 2024
1 parent cadc6f1 commit d6c74c2
Show file tree
Hide file tree
Showing 20 changed files with 529 additions and 64 deletions.
24 changes: 22 additions & 2 deletions qiskit/circuit/classical/expr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
.. autoclass:: Var
:members: var, name
Similarly, literals used in comparison (such as integers) should be lifted to :class:`Value` nodes
Similarly, literals used in expressions (such as integers) should be lifted to :class:`Value` nodes
with associated types.
.. autoclass:: Value
Expand All @@ -62,6 +62,12 @@
:members: Op
:member-order: bysource
Bit-like types (unsigned integers) can be indexed by integer types, represented by :class:`Index`.
The result is a single bit. The resulting expression has an associated memory location (and so can
be used as an lvalue for :class:`.Store`, etc) if the target is also an lvalue.
.. autoclass:: Index
When constructing expressions, one must ensure that the types are valid for the operation.
Attempts to construct expressions with invalid types will raise a regular Python ``TypeError``.
Expand Down Expand Up @@ -122,6 +128,13 @@
.. autofunction:: less_equal
.. autofunction:: greater
.. autofunction:: greater_equal
.. autofunction:: shift_left
.. autofunction:: shift_right
You can index into unsigned integers and bit-likes using another unsigned integer of any width.
This includes in storing operations, if the target of the index is writeable.
.. autofunction:: index
Qiskit's legacy method for specifying equality conditions for use in conditionals is to use a
two-tuple of a :class:`.Clbit` or :class:`.ClassicalRegister` and an integer. This represents an
Expand Down Expand Up @@ -174,6 +187,7 @@
"Cast",
"Unary",
"Binary",
"Index",
"ExprVisitor",
"iter_vars",
"structurally_equivalent",
Expand All @@ -185,6 +199,8 @@
"bit_and",
"bit_or",
"bit_xor",
"shift_left",
"shift_right",
"logic_and",
"logic_or",
"equal",
Expand All @@ -193,10 +209,11 @@
"less_equal",
"greater",
"greater_equal",
"index",
"lift_legacy_condition",
]

from .expr import Expr, Var, Value, Cast, Unary, Binary
from .expr import Expr, Var, Value, Cast, Unary, Binary, Index
from .visitors import ExprVisitor, iter_vars, structurally_equivalent, is_lvalue
from .constructors import (
lift,
Expand All @@ -214,5 +231,8 @@
less_equal,
greater,
greater_equal,
shift_left,
shift_right,
index,
lift_legacy_condition,
)
85 changes: 84 additions & 1 deletion qiskit/circuit/classical/expr/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

import typing

from .expr import Expr, Var, Value, Unary, Binary, Cast
from .expr import Expr, Var, Value, Unary, Binary, Cast, Index
from ..types import CastKind, cast_kind
from .. import types

Expand Down Expand Up @@ -471,3 +471,86 @@ def greater_equal(left: typing.Any, right: typing.Any, /) -> Expr:
Uint(3))
"""
return _binary_relation(Binary.Op.GREATER_EQUAL, left, right)


def _shift_like(
op: Binary.Op, left: typing.Any, right: typing.Any, type: types.Type | None
) -> Expr:
if type is not None and type.kind is not types.Uint:
raise TypeError(f"type '{type}' is not a valid bitshift operand type")
if isinstance(left, Expr):
left = _coerce_lossless(left, type) if type is not None else left
else:
left = lift(left, type)
right = lift(right)
if left.type.kind != types.Uint or right.type.kind != types.Uint:
raise TypeError(f"invalid types for '{op}': '{left.type}' and '{right.type}'")
return Binary(op, left, right, left.type)


def shift_left(left: typing.Any, right: typing.Any, /, type: types.Type | None = None) -> Expr:
"""Create a 'bitshift left' expression node from the given two values, resolving any implicit
casts and lifting the values into :class:`Value` nodes if required.
If ``type`` is given, the ``left`` operand will be coerced to it (if possible).
Examples:
Shift the value of a standalone variable left by some amount::
>>> from qiskit.circuit.classical import expr, types
>>> a = expr.Var.new("a", types.Uint(8))
>>> expr.shift_left(a, 4)
Binary(Binary.Op.SHIFT_LEFT, \
Var(<UUID>, Uint(8), name='a'), \
Value(4, Uint(3)), \
Uint(8))
Shift an integer literal by a variable amount, coercing the type of the literal::
>>> expr.shift_left(3, a, types.Uint(16))
Binary(Binary.Op.SHIFT_LEFT, \
Value(3, Uint(16)), \
Var(<UUID>, Uint(8), name='a'), \
Uint(16))
"""
return _shift_like(Binary.Op.SHIFT_LEFT, left, right, type)


def shift_right(left: typing.Any, right: typing.Any, /, type: types.Type | None = None) -> Expr:
"""Create a 'bitshift right' expression node from the given values, resolving any implicit casts
and lifting the values into :class:`Value` nodes if required.
If ``type`` is given, the ``left`` operand will be coerced to it (if possible).
Examples:
Shift the value of a classical register right by some amount::
>>> from qiskit.circuit import ClassicalRegister
>>> from qiskit.circuit.classical import expr
>>> expr.shift_right(ClassicalRegister(8, "a"), 4)
Binary(Binary.Op.SHIFT_RIGHT, \
Var(ClassicalRegister(8, "a"), Uint(8)), \
Value(4, Uint(3)), \
Uint(8))
"""
return _shift_like(Binary.Op.SHIFT_RIGHT, left, right, type)


def index(target: typing.Any, index: typing.Any, /) -> Expr:
"""Index into the ``target`` with the given integer ``index``, lifting the values into
:class:`Value` nodes if required.
This can be used as the target of a :class:`.Store`, if the ``target`` is itself an lvalue.
Examples:
Index into a classical register with a literal::
>>> from qiskit.circuit import ClassicalRegister
>>> from qiskit.circuit.classical import expr
>>> expr.index(ClassicalRegister(8, "a"), 3)
Index(Var(ClassicalRegister(8, "a"), Uint(8)), Value(3, Uint(2)), Bool())
"""
target, index = lift(target), lift(index)
if target.type.kind is not types.Uint or index.type.kind is not types.Uint:
raise TypeError(f"invalid types for indexing: '{target.type}' and '{index.type}'")
return Index(target, index, types.Bool())
41 changes: 41 additions & 0 deletions qiskit/circuit/classical/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,11 @@ class Op(enum.Enum):
The binary mathematical relations :data:`EQUAL`, :data:`NOT_EQUAL`, :data:`LESS`,
:data:`LESS_EQUAL`, :data:`GREATER` and :data:`GREATER_EQUAL` take unsigned integers
(with an implicit cast to make them the same width), and return a Boolean.
The bitshift operations :data:`SHIFT_LEFT` and :data:`SHIFT_RIGHT` can take bit-like
container types (e.g. unsigned integers) as the left operand, and any integer type as the
right-hand operand. In all cases, the output bit width is the same as the input, and zeros
fill in the "exposed" spaces.
"""

# If adding opcodes, remember to add helper constructor functions in `constructors.py`
Expand Down Expand Up @@ -327,6 +332,10 @@ class Op(enum.Enum):
"""Numeric greater than. ``lhs > rhs``."""
GREATER_EQUAL = 11
"""Numeric greater than or equal to. ``lhs >= rhs``."""
SHIFT_LEFT = 12
"""Zero-padding bitshift to the left. ``lhs << rhs``."""
SHIFT_RIGHT = 13
"""Zero-padding bitshift to the right. ``lhs >> rhs``."""

def __str__(self):
return f"Binary.{super().__str__()}"
Expand Down Expand Up @@ -354,3 +363,35 @@ def __eq__(self, other):

def __repr__(self):
return f"Binary({self.op}, {self.left}, {self.right}, {self.type})"


@typing.final
class Index(Expr):
"""An indexing expression.
Args:
target: The object being indexed.
index: The expression doing the indexing.
type: The resolved type of the result.
"""

__slots__ = ("target", "index")

def __init__(self, target: Expr, index: Expr, type: types.Type):
self.target = target
self.index = index
self.type = type

def accept(self, visitor, /):
return visitor.visit_index(self)

def __eq__(self, other):
return (
isinstance(other, Index)
and self.type == other.type
and self.target == other.target
and self.index == other.index
)

def __repr__(self):
return f"Index({self.target}, {self.index}, {self.type})"
20 changes: 20 additions & 0 deletions qiskit/circuit/classical/expr/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def visit_binary(self, node: expr.Binary, /) -> _T_co: # pragma: no cover
def visit_cast(self, node: expr.Cast, /) -> _T_co: # pragma: no cover
return self.visit_generic(node)

def visit_index(self, node: expr.Index, /) -> _T_co: # pragma: no cover
return self.visit_generic(node)


class _VarWalkerImpl(ExprVisitor[typing.Iterable[expr.Var]]):
__slots__ = ()
Expand All @@ -75,6 +78,10 @@ def visit_binary(self, node, /):
def visit_cast(self, node, /):
yield from node.operand.accept(self)

def visit_index(self, node, /):
yield from node.target.accept(self)
yield from node.index.accept(self)


_VAR_WALKER = _VarWalkerImpl()

Expand Down Expand Up @@ -164,6 +171,16 @@ def visit_cast(self, node, /):
self.other = self.other.operand
return node.operand.accept(self)

def visit_index(self, node, /):
if self.other.__class__ is not node.__class__ or self.other.type != node.type:
return False
other = self.other
self.other = other.target
if not node.target.accept(self):
return False
self.other = other.index
return node.index.accept(self)


def structurally_equivalent(
left: expr.Expr,
Expand Down Expand Up @@ -235,6 +252,9 @@ def visit_binary(self, node, /):
def visit_cast(self, node, /):
return False

def visit_index(self, node, /):
return node.target.accept(self)


_IS_LVALUE = _IsLValueImpl()

Expand Down
8 changes: 8 additions & 0 deletions qiskit/qasm3/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ class Op(enum.Enum):
GREATER_EQUAL = ">="
EQUAL = "=="
NOT_EQUAL = "!="
SHIFT_LEFT = "<<"
SHIFT_RIGHT = ">>"

def __init__(self, op: Op, left: Expression, right: Expression):
self.op = op
Expand All @@ -265,6 +267,12 @@ def __init__(self, type: ClassicalType, operand: Expression):
self.operand = operand


class Index(Expression):
def __init__(self, target: Expression, index: Expression):
self.target = target
self.index = index


class IndexSet(ASTNode):
"""
A literal index set of values::
Expand Down
3 changes: 3 additions & 0 deletions qiskit/qasm3/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,3 +1165,6 @@ def visit_binary(self, node, /):
return ast.Binary(
ast.Binary.Op[node.op.name], node.left.accept(self), node.right.accept(self)
)

def visit_index(self, node, /):
return ast.Index(node.target.accept(self), node.index.accept(self))
24 changes: 19 additions & 5 deletions qiskit/qasm3/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@
# indexing and casting are all higher priority than these, so we just ignore them.
_BindingPower = collections.namedtuple("_BindingPower", ("left", "right"), defaults=(255, 255))
_BINDING_POWER = {
# Power: (21, 22)
# Power: (24, 23)
#
ast.Unary.Op.LOGIC_NOT: _BindingPower(right=20),
ast.Unary.Op.BIT_NOT: _BindingPower(right=20),
ast.Unary.Op.LOGIC_NOT: _BindingPower(right=22),
ast.Unary.Op.BIT_NOT: _BindingPower(right=22),
#
# Multiplication/division/modulo: (17, 18)
# Addition/subtraction: (15, 16)
# Multiplication/division/modulo: (19, 20)
# Addition/subtraction: (17, 18)
#
ast.Binary.Op.SHIFT_LEFT: _BindingPower(15, 16),
ast.Binary.Op.SHIFT_RIGHT: _BindingPower(15, 16),
#
ast.Binary.Op.LESS: _BindingPower(13, 14),
ast.Binary.Op.LESS_EQUAL: _BindingPower(13, 14),
Expand Down Expand Up @@ -332,6 +335,17 @@ def _visit_Cast(self, node: ast.Cast):
self.visit(node.operand)
self.stream.write(")")

def _visit_Index(self, node: ast.Index):
if isinstance(node.target, (ast.Unary, ast.Binary)):
self.stream.write("(")
self.visit(node.target)
self.stream.write(")")
else:
self.visit(node.target)
self.stream.write("[")
self.visit(node.index)
self.stream.write("]")

def _visit_ClassicalDeclaration(self, node: ast.ClassicalDeclaration) -> None:
self._start_line()
self.visit(node.type)
Expand Down
15 changes: 15 additions & 0 deletions qiskit/qpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,21 @@
Notably, this new type-code indexes into pre-defined variables from the circuit header, rather than
redefining the variable again in each location it is used.
Changes to EXPRESSION
---------------------
The EXPRESSION type code has a new possible entry, ``i``, corresponding to :class:`.expr.Index`
nodes.
====================== ========= ======================================================= ========
Qiskit class Type code Payload Children
====================== ========= ======================================================= ========
:class:`~.expr.Index` ``i`` No additional payload. The children are the target 2
and the index, in that order.
====================== ========= ======================================================= ========
.. _qpy_version_11:
Version 11
Expand Down
Loading

0 comments on commit d6c74c2

Please sign in to comment.