bugfix(backend): fix some bugs in axons set (#136)
* fix bug in coreblock axons

* add test for ordered axons

* 🚨 auto fix by pre-commit hooks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and KafCoppelia committed Nov 29, 2024
1 parent d216eaa commit 52c5231
Showing 7 changed files with 392 additions and 65 deletions.
74 changes: 74 additions & 0 deletions paibox/backend/graphs.py
@@ -557,6 +557,80 @@ def _degree_check(
)


def find_cycles(directed_edges: Mapping[_NT, Iterable[_NT]]) -> list[list[_NT]]:
    cycles: list[list[_NT]] = []
    visited: set[_NT] = set()
    stack: list[_NT] = []
    stack_set: set[_NT] = set()  # for fast membership checks of nodes on the current path

    # Helper function for depth-first search.
    def dfs(node: _NT):
        if node in stack_set:  # a cycle is detected
            cycle_start_index = stack.index(node)
            cycles.append(stack[cycle_start_index:])
            return
        if node in visited:
            return

        visited.add(node)
        stack.append(node)
        stack_set.add(node)

        for neighbor in directed_edges.get(node, []):
            dfs(neighbor)

        stack.pop()
        stack_set.remove(node)

    # Traverse every node to find all possible cycles.
    for node in directed_edges:
        if node not in visited:
            dfs(node)

    return cycles
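
For readers skimming the diff, a minimal usage sketch of `find_cycles` (hypothetical node names, not part of the commit): given an adjacency mapping, it returns each detected cycle as the slice of the DFS stack that closes it.

# Hypothetical example: B -> C -> D -> B forms the only cycle.
edges = {"A": ["B"], "B": ["C"], "C": ["D"], "D": ["B"]}
print(find_cycles(edges))  # [['B', 'C', 'D']]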


def merge_overlap(groups: Iterable[Iterable[_NT]]) -> list[list[_NT]]:
    # Union-find (disjoint-set) data structure.
    parent: dict[_NT, _NT] = dict()

    # Find the root of a set, with path compression.
    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]

    # Union two sets.
    def union(x, y):
        rootX = find(x)
        rootY = find(y)
        if rootX != rootY:
            parent[rootY] = rootX

    # Initialize the union-find structure.
    for group in groups:
        for element in group:
            if element not in parent:
                parent[element] = element

    # Union all overlapping cycles.
    for group in groups:
        first_element = group[0]
        for element in group[1:]:
            union(first_element, element)

    # Group every node under its root according to the union-find result.
    merged_groups: dict[_NT, list[_NT]] = dict()
    for element in parent:
        root = find(element)
        if root not in merged_groups:
            merged_groups[root] = []
        merged_groups[root].append(element)

    # Convert the result into a list of lists.
    return list(merged_groups.values())
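
Likewise, a small sketch of `merge_overlap` (hypothetical groups, not part of the commit): groups that share at least one element collapse into a single union-find class.

# Hypothetical example: the first two groups share "B", so they merge; the third stays separate.
cycles = [["A", "B"], ["B", "C"], ["D", "E"]]
print(merge_overlap(cycles))  # [['A', 'B', 'C'], ['D', 'E']]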


def toposort(directed_edges: Mapping[_NT, Iterable[_NT]]) -> list[_NT]:
    """
    Topological sort algorithm by Kahn [1]_.
67 changes: 57 additions & 10 deletions paibox/backend/mapper.py
@@ -22,10 +22,24 @@
    OutputDestConf,
)
from .context import _BACKEND_CONTEXT, set_cflag
from .graphs import PAIGraph, get_node_degrees, get_succ_cb_by_node, toposort
from .graphs import (
    PAIGraph,
    find_cycles,
    get_node_degrees,
    get_succ_cb_by_node,
    merge_overlap,
    toposort,
)
from .placement import CoreBlock, aligned_coords, max_lcn_of_cb
from .routing import RoutingGroup, RoutingManager
from .types import NeuSegment, NodeDegree, NodeType, SourceNodeType, is_iw8
from .types import (
    MergedSuccGroup,
    NeuSegment,
    NodeDegree,
    NodeType,
    SourceNodeType,
    is_iw8,
)

__all__ = ["Mapper"]

@@ -202,10 +216,19 @@ def untwist_branch_nodes(self) -> None:

    def build_core_blocks(self) -> None:
        """Build core blocks based on partitioned edges."""
        merged_sgrps = self.graph.graph_partition()
        merged_sgrps: list[MergedSuccGroup] = self.graph.graph_partition()
        merged_sgrps: list[MergedSuccGroup] = cycle_merge(merged_sgrps)

        for msgrp in merged_sgrps:
            self.routing_groups.append(RoutingGroup.build(msgrp))
            self.routing_groups.append(RoutingGroup.build(msgrp, True))

        routing_groups: list[RoutingGroup] = list()
        for rg in self.routing_groups:
            routing_groups.extend(rg.optimize_group())
        self.routing_groups = routing_groups

        for rg in self.routing_groups:
            rg.dump()

        for rg in self.routing_groups:
            self.core_blocks += rg.core_blocks
@@ -214,7 +237,7 @@
            succ_cbs: list[CoreBlock] = []
            # cur_cb == cb is possible
            for cb in self.core_blocks:
                if any(d for d in cur_cb.dest if d in cb.source):
                if any(d for d in cur_cb.dest if d in cb.ordered_axons):
                    succ_cbs.append(cb)

            self.succ_core_blocks[cur_cb] = succ_cbs
@@ -274,8 +297,8 @@ def lcn_ex_adjustment(self) -> None:

    def cb_axon_grouping(self) -> None:
        """The axons are grouped after the LCN has been modified & locked."""
        for rg in self.routing_groups:
            rg.group_axons()
        for core_block in self.core_blocks:
            core_block.group_axons()

    def graph_optimization(self) -> None:
        optimized = self.graph.graph_optimization(self.core_blocks, self.routing_groups)
@@ -416,7 +439,7 @@ def _inpproj_config_export(self) -> InputNodeConf:
            # LCN of `input_cbs` are the same.
            input_cb = input_cbs[0]
            axon_coords = aligned_coords(
                slice(0, input_cb.n_axon_of(input_cb.source.index(inode)), 1),
                slice(0, input_cb.n_axon_of(input_cb.ordered_axons.index(inode)), 1),
                input_cb.axon_segments[inode],
                1,
                input_cb.n_timeslot,
@@ -646,7 +669,7 @@ def find_axon(self, neuron: Neuron, *, verbose: int = 0) -> None:

        for cb in self.core_blocks:
            # Find neuron in one or more core blocks.
            if neuron in cb.source:
            if neuron in cb.ordered_axons:
                print(f"axons {neuron.name} placed in {cb.name}, LCN_{1 << cb.lcn_ex}X")
                axon_segment = cb.axon_segments[neuron]
                print(
@@ -663,11 +686,35 @@ def _find_dest_cb_by_nseg(
        self, neu_seg: NeuSegment, cb: CoreBlock
    ) -> list[CoreBlock]:
        succ_cbs = self.succ_core_blocks[cb]
        dest_cb_of_nseg = [cb for cb in succ_cbs if neu_seg.target in cb.source]
        dest_cb_of_nseg = [cb for cb in succ_cbs if neu_seg.target in cb.ordered_axons]

        return dest_cb_of_nseg


def cycle_merge(merged_sgrps: list[MergedSuccGroup]):
    succ_merged_sgrps: dict[MergedSuccGroup, list[MergedSuccGroup]] = dict()
    for msgrp in merged_sgrps:
        succ_merged_sgrps[msgrp] = []
        nodes = set(msgrp.nodes)
        for _msgrp in merged_sgrps:
            if msgrp == _msgrp:
                continue
            if not nodes.isdisjoint(_msgrp.input_nodes):
                succ_merged_sgrps[msgrp].append(_msgrp)

    cycles: list[list[MergedSuccGroup]] = find_cycles(succ_merged_sgrps)
    merged_cycles: list[list[MergedSuccGroup]] = merge_overlap(cycles)

    processed_merged_cycles: list[MergedSuccGroup] = list()
    remaining_merged_sgrps: set[MergedSuccGroup] = set(merged_sgrps)
    for merged_cycle in merged_cycles:
        processed_merged_cycles.append(MergedSuccGroup.merge(merged_cycle))
        for msgrp in merged_cycle:
            remaining_merged_sgrps.remove(msgrp)
    processed_merged_cycles.extend(remaining_merged_sgrps)
    return processed_merged_cycles
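
`cycle_merge` chains the two new helpers on `MergedSuccGroup` objects: build a successor mapping between groups, detect cycles, union overlapping cycles, and keep the remaining groups untouched. A rough sketch of the same pattern with plain strings standing in for the groups (hypothetical, not the commit's types):

# Hypothetical sketch of the cycle_merge pattern with strings instead of MergedSuccGroup.
succ = {"g1": ["g2"], "g2": ["g1"], "g3": ["g4"], "g4": []}
cycles = find_cycles(succ)             # [['g1', 'g2']]
merged_cycles = merge_overlap(cycles)  # [['g1', 'g2']]
in_cycle = {g for cyc in merged_cycles for g in cyc}
result = merged_cycles + [[g] for g in succ if g not in in_cycle]
print(result)  # [['g1', 'g2'], ['g3'], ['g4']] -- cyclic groups merged, the rest kept as-is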


def group_by(dict_: dict, keyfunc=lambda item: item):
    """Groups the given list or dictionary by the value returned by ``keyfunc``."""
    d = defaultdict(list)
33 changes: 22 additions & 11 deletions paibox/backend/placement.py
@@ -93,7 +93,7 @@ def __init__(
        self._parents = parents
        self.rt_mode = mode
        self.seed = seed
        self._lcn_ex = self._n_axon2lcn_ex()
        self._lcn_ex = LCN_EX.LCN_1X

        self.target_lcn = LCN_EX.LCN_1X
        self._lcn_locked = False
@@ -102,7 +102,7 @@
        self.core_placements = dict()
        self.axon_segments = dict()
        self.neuron_segs_of_cb = []
        self.ordered_axons: list[SourceNodeType] = []
        self._ordered_axons: list[SourceNodeType] = []
        """Axons in private + multicast order."""

    def group_neurons(
@@ -172,7 +172,7 @@ def obj(self) -> tuple[FullConnectedSyn, ...]:

    @property
    def shape(self) -> tuple[int, int]:
        return (len(self.source), len(self.dest))
        return (len(self.ordered_axons), len(self.dest))

    @property
    def source(self) -> list[SourceNodeType]:
@@ -190,7 +190,7 @@ def dest(self) -> list[DestNodeType]:

    def n_axon_of(self, index: int) -> int:
        """Get the #N of axons of `index`-th source neuron."""
        return self.axons[index].num_out
        return self.ordered_axons[index].num_out

    """Boundary limitations"""

@@ -275,7 +275,7 @@ def pool_max(self) -> MaxPoolingEnable:

    @property
    def n_axon(self) -> int:
        return sum(s.num_out for s in self.axons)
        return sum(s.num_out for s in self.ordered_axons)

    @property
    def n_fanout(self) -> int:
@@ -307,17 +307,21 @@ def n_neuron_of_plm(self) -> list[int]:
            for neuron_segs in self.neuron_segs_of_cb
        ]

    def group_axons(self, multicast_axons: list[SourceNodeType] = []) -> None:
    @property
    def ordered_axons(self) -> list[SourceNodeType]:
        return self._ordered_axons

    @ordered_axons.setter
    def ordered_axons(self, axons: list[SourceNodeType]):
        self._ordered_axons = axons
        self._lcn_ex = self._n_axon2lcn_ex()

    def group_axons(self) -> None:
        """Group the axons, including the private & the multicast parts.
        NOTE: Take the union of the private axons & the multicast axons, but sort the multicast axons first, then the \
            axons that are in the private part and not in the multicast part.
        """
        if not self._lcn_locked:
            raise GraphBuildError("group axons after 'lcn_ex' is locked.")

        axons = multicast_axons + [ax for ax in self.axons if ax not in multicast_axons]
        self.ordered_axons = axons
        self.axon_segments = get_axon_segments(
            self.ordered_axons, self.n_timeslot, self.n_fanin_base
        )
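
For context on why the `ordered_axons` setter recomputes `_lcn_ex`: the LCN extension depends on the total axon count relative to the core's fan-in limit, so it can only be settled once the final, ordered axon set is assigned (which is why `__init__` now starts from `LCN_1X`). A rough, purely hypothetical illustration of that dependency, not the commit's actual `_n_axon2lcn_ex` implementation:

# Hypothetical illustration only: LCN grows with the axon count relative to a
# per-core fan-in limit (assumed here to be 1152, as on PAICORE).
import math

def lcn_ex_of(n_axon: int, n_fanin_base: int = 1152) -> int:
    if n_axon <= 0:
        raise ValueError("n_axon must be positive")
    return max(0, math.ceil(math.log2(math.ceil(n_axon / n_fanin_base))))

print(lcn_ex_of(1000))  # 0 -> LCN_1X
print(lcn_ex_of(3000))  # 2 -> LCN_4X
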
@@ -435,6 +439,13 @@ def export_core_plm_config(cls, cb: "CoreBlock") -> CoreConfInChip:

        return cb_config

    def dump(self, i: int = 0) -> None:
        tabs = "\t" * i
        print(f"{tabs}{self.name} with {self.n_core_required} cores:")
        print(f"{tabs}\tLCN: {self.lcn_ex}")
        for edge in self._parents:
            print(f"{tabs}\t{edge.name}: {edge.source.name} -> {edge.target.name}")


class CorePlacement(CoreAbstract):
    parent: CoreBlock