Skip to content

Commit

Permalink
Merge pull request #255 from bachbao/surround-filter
Browse files Browse the repository at this point in the history
Modifying surround function
  • Loading branch information
edyounis authored Jul 10, 2024
2 parents 44040ff + 590f5ce commit f427e4c
Showing 1 changed file with 44 additions and 5 deletions.
49 changes: 44 additions & 5 deletions bqskit/ir/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2216,21 +2216,26 @@ def score(node: Node) -> int:
return sum(op[1].num_qudits for op in node[1])

best_score = score(init_node)
_logger.debug(f'best_score: {best_score}')
best_region = self.get_region({(point[0], init_op.location[0])})

# Exhaustive Search
while len(frontier) > 0:
#_logger.debug(f'Current frontiers:{frontier}')
best_score_flag = False
node = frontier.pop(0)
_logger.debug('popped node:')
_logger.debug(node[0])
_logger.debug(f'Items remaining in the frontier: {len(frontier)}')

if filter is not None and not filter(node):
_logger.debug('Node failed filter; skipping.')
_logger.debug(f'Node failed location: {node[2]}')
continue

# Evaluate node
if score(node) > best_score:
_logger.debug(f'current_score: {score(node)}')
if score(node) > best_score: # or (score(node) == best_score and len(node[2]) > len(best_region)):
# Calculate region from best node and return
points = {(cycle, op.location[0]) for cycle, op in node[1]}

Expand All @@ -2243,13 +2248,39 @@ def score(node: Node) -> int:
except ValueError:
if fail_quickly:
continue
elif score(node) == best_score:
_logger.debug(f'current score is {score(node)} and is equal to the best score.')
best_score_flag = True
# Bao's comment: The reason why I considering the node where the score is at least the same with
# the best score is due to the need to expand the region to at least the number of qubits that we required
# Why this works: as the filtering now limit the region to what we want, the node that we assuming should
# satisfy the filtering and bounded by the region that we want. Therefore, by using union to take into account
# all regions, we hope to cover the amount of qubits that we want

# elif score(node) == best_score and len(best_region) < num_qudits:
# # Calculate region from best node and return
# points = {(cycle, op.location[0]) for cycle, op in node[1]}
# try:
# new_region = self.get_region(points)
# _logger.debug(f'new region: {new_region}')
# # If two region is different, merge them. As the region is bounded, we do not have the case where two
# # regions has no overlap
# if new_region.location != best_region.location:
# best_region = best_region.union(new_region)
# best_score = score(node)
# _logger.debug(f'new best: {best_region}.')
#
# # Need to reject bad regions
# except ValueError:
# if fail_quickly:
# continue

# Expand node
absorbed_gates: set[tuple[int, Operation]] = set()
branches: set[tuple[int, int, Operation]] = set()
before_branch_half_wires: dict[int, HalfWire] = {}
for i, half_wire in enumerate(node[0]):

_logger.debug(f"Exploring {half_wire} .....")
cycle_index, qudit_index = half_wire[0]
step = -1 if half_wire[1] == 'left' else 1

Expand All @@ -2270,8 +2301,11 @@ def score(node: Node) -> int:
# Stop when exploring previously explored points
point = CircuitPoint(cycle_index, qudit_index)
if point in node[3]:
break
node[3].add(point)
if not best_score_flag:
_logger.debug(f"Skipping op {self[point]} because previously seen.")
break
else:
node[3].add(point)

# Continue until next operation
if self.is_point_idle(point):
Expand Down Expand Up @@ -2299,6 +2333,7 @@ def score(node: Node) -> int:
break

# Otherwise branch on the operation
_logger.debug(f"Adding {(i, cycle_index, op)} to branch")
branches.add((i, cycle_index, op))

# Track state of half wire right before branch
Expand All @@ -2308,7 +2343,7 @@ def score(node: Node) -> int:

# Compute children and extend frontier
for half_wire_index, cycle_index, op in branches:

_logger.debug(f"Expanding branch {(half_wire_index, cycle_index, op)}")
child_half_wires = [
half_wire
for i, half_wire in before_branch_half_wires.items()
Expand All @@ -2330,6 +2365,7 @@ def score(node: Node) -> int:
expansion = left_expansion + right_expansion

# Branch/Gate not taken
_logger.debug(f"Branch/gate not taken: {(child_half_wires, node[1] | absorbed_gates, node[2], node[3])}")
frontier.append((
child_half_wires,
node[1] | absorbed_gates,
Expand All @@ -2339,6 +2375,8 @@ def score(node: Node) -> int:

# Branch/Gate taken
op_points = {CircuitPoint(cycle_index, q) for q in op.location}
_logger.debug(
f"Branch/Gate taken: {(list(set(child_half_wires + expansion)), node[1] | absorbed_gates | {(cycle_index, op)}, node[2].union(op.location), node[3] | op_points)}")
frontier.append((
list(set(child_half_wires + expansion)),
node[1] | absorbed_gates | {(cycle_index, op)},
Expand All @@ -2349,6 +2387,7 @@ def score(node: Node) -> int:
# Append terminal node to handle absorbed gates with no branches
if len(node[1] | absorbed_gates) != len(node[1]):
frontier.append(([], node[1] | absorbed_gates, *node[2:]))
_logger.debug(f"Terminal node {frontier[-1]}")

return best_region

Expand Down

0 comments on commit f427e4c

Please sign in to comment.