Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modifying surround function #255

Merged
merged 2 commits into from
Jul 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions bqskit/ir/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2216,21 +2216,26 @@ def score(node: Node) -> int:
return sum(op[1].num_qudits for op in node[1])

best_score = score(init_node)
_logger.debug(f'best_score: {best_score}')
best_region = self.get_region({(point[0], init_op.location[0])})

# Exhaustive Search
while len(frontier) > 0:
#_logger.debug(f'Current frontiers:{frontier}')
best_score_flag = False
node = frontier.pop(0)
_logger.debug('popped node:')
_logger.debug(node[0])
_logger.debug(f'Items remaining in the frontier: {len(frontier)}')

if filter is not None and not filter(node):
_logger.debug('Node failed filter; skipping.')
_logger.debug(f'Node failed location: {node[2]}')
continue

# Evaluate node
if score(node) > best_score:
_logger.debug(f'current_score: {score(node)}')
if score(node) > best_score: # or (score(node) == best_score and len(node[2]) > len(best_region)):
# Calculate region from best node and return
points = {(cycle, op.location[0]) for cycle, op in node[1]}

Expand All @@ -2243,13 +2248,39 @@ def score(node: Node) -> int:
except ValueError:
if fail_quickly:
continue
elif score(node) == best_score:
_logger.debug(f'current score is {score(node)} and is equal to the best score.')
best_score_flag = True
# Bao's comment: The reason why I considering the node where the score is at least the same with
# the best score is due to the need to expand the region to at least the number of qubits that we required
# Why this works: as the filtering now limit the region to what we want, the node that we assuming should
# satisfy the filtering and bounded by the region that we want. Therefore, by using union to take into account
# all regions, we hope to cover the amount of qubits that we want

# elif score(node) == best_score and len(best_region) < num_qudits:
# # Calculate region from best node and return
# points = {(cycle, op.location[0]) for cycle, op in node[1]}
# try:
# new_region = self.get_region(points)
# _logger.debug(f'new region: {new_region}')
# # If two region is different, merge them. As the region is bounded, we do not have the case where two
# # regions has no overlap
# if new_region.location != best_region.location:
# best_region = best_region.union(new_region)
# best_score = score(node)
# _logger.debug(f'new best: {best_region}.')
#
# # Need to reject bad regions
# except ValueError:
# if fail_quickly:
# continue

# Expand node
absorbed_gates: set[tuple[int, Operation]] = set()
branches: set[tuple[int, int, Operation]] = set()
before_branch_half_wires: dict[int, HalfWire] = {}
for i, half_wire in enumerate(node[0]):

_logger.debug(f"Exploring {half_wire} .....")
cycle_index, qudit_index = half_wire[0]
step = -1 if half_wire[1] == 'left' else 1

Expand All @@ -2270,8 +2301,11 @@ def score(node: Node) -> int:
# Stop when exploring previously explored points
point = CircuitPoint(cycle_index, qudit_index)
if point in node[3]:
break
node[3].add(point)
if not best_score_flag:
_logger.debug(f"Skipping op {self[point]} because previously seen.")
break
else:
node[3].add(point)

# Continue until next operation
if self.is_point_idle(point):
Expand Down Expand Up @@ -2299,6 +2333,7 @@ def score(node: Node) -> int:
break

# Otherwise branch on the operation
_logger.debug(f"Adding {(i, cycle_index, op)} to branch")
branches.add((i, cycle_index, op))

# Track state of half wire right before branch
Expand All @@ -2308,7 +2343,7 @@ def score(node: Node) -> int:

# Compute children and extend frontier
for half_wire_index, cycle_index, op in branches:

_logger.debug(f"Expanding branch {(half_wire_index, cycle_index, op)}")
child_half_wires = [
half_wire
for i, half_wire in before_branch_half_wires.items()
Expand All @@ -2330,6 +2365,7 @@ def score(node: Node) -> int:
expansion = left_expansion + right_expansion

# Branch/Gate not taken
_logger.debug(f"Branch/gate not taken: {(child_half_wires, node[1] | absorbed_gates, node[2], node[3])}")
frontier.append((
child_half_wires,
node[1] | absorbed_gates,
Expand All @@ -2339,6 +2375,8 @@ def score(node: Node) -> int:

# Branch/Gate taken
op_points = {CircuitPoint(cycle_index, q) for q in op.location}
_logger.debug(
f"Branch/Gate taken: {(list(set(child_half_wires + expansion)), node[1] | absorbed_gates | {(cycle_index, op)}, node[2].union(op.location), node[3] | op_points)}")
frontier.append((
list(set(child_half_wires + expansion)),
node[1] | absorbed_gates | {(cycle_index, op)},
Expand All @@ -2349,6 +2387,7 @@ def score(node: Node) -> int:
# Append terminal node to handle absorbed gates with no branches
if len(node[1] | absorbed_gates) != len(node[1]):
frontier.append(([], node[1] | absorbed_gates, *node[2:]))
_logger.debug(f"Terminal node {frontier[-1]}")

return best_region

Expand Down