Skip to content

Commit

Permalink
Use cpdef's union and find
Browse files Browse the repository at this point in the history
Typecast `union` as `void` to a-void the "C struct/union cannot be declared cpdef" error.
  • Loading branch information
gmou3 committed Apr 22, 2024
1 parent 1280f25 commit fad60ad
Show file tree
Hide file tree
Showing 18 changed files with 203 additions and 330 deletions.
32 changes: 16 additions & 16 deletions src/sage/combinat/bijectionist.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ def set_constant_blocks(self, P):
P = sorted(self._sorter["A"](p) for p in P)
for p in P:
for a in p:
self._P._union(p[0], a)
self._P.union(p[0], a)

self._compute_possible_block_values()

Expand Down Expand Up @@ -1678,7 +1678,7 @@ def merge_until_split():
try:
solution = different_values(tP[i1], tP[i2])
except StopIteration:
tmp_P._union(tP[i1], tP[i2])
tmp_P.union(tP[i1], tP[i2])
if len(multiple_preimages[tZ]) == 2:
del multiple_preimages[tZ]
else:
Expand Down Expand Up @@ -1772,14 +1772,14 @@ def possible_values(self, p=None, optimal=False):
# convert input to set of block representatives
blocks = set()
if p in self._A:
blocks.add(self._P._find(p))
blocks.add(self._P.find(p))
elif isinstance(p, list): # TODO: this looks very brittle
for p1 in p:
if p1 in self._A:
blocks.add(self._P._find(p1))
blocks.add(self._P.find(p1))
elif isinstance(p1, list):
for p2 in p1:
blocks.add(self._P._find(p2))
blocks.add(self._P.find(p2))

if optimal:
if self._bmilp is None:
Expand Down Expand Up @@ -1941,9 +1941,9 @@ def _find_counterexample(self, P, s0, d, on_blocks):

# try to find a solution which has a different
# subdistribution on d than s0
z_in_d = sum(d[p] * bmilp._x[self._P._find(p), z]
z_in_d = sum(d[p] * bmilp._x[self._P.find(p), z]
for p in P
if z in self._possible_block_values[self._P._find(p)])
if z in self._possible_block_values[self._P.find(p)])

# it is sufficient to require that z occurs less often as
# a value among {a | d[a] == 1} than it does in
Expand Down Expand Up @@ -2191,14 +2191,14 @@ def _preprocess_intertwining_relations(self):
# the blocks of the elements of the preimage
updated_images = defaultdict(set) # (p_1,...,p_k) to {a_1,....}
for a_tuple, image_set in images.items():
representatives = tuple(P._find(a) for a in a_tuple)
representatives = tuple(P.find(a) for a in a_tuple)
updated_images[representatives].update(image_set)

# merge blocks
for a_tuple, image_set in updated_images.items():
image = image_set.pop()
while image_set:
P._union(image, image_set.pop())
P.union(image, image_set.pop())
something_changed = True
# we keep a representative
image_set.add(image)
Expand Down Expand Up @@ -2525,7 +2525,7 @@ def __init__(self, bijectionist: Bijectionist, solutions=None):
self._solution_cache = []
if solutions is not None:
for solution in solutions:
self._add_solution({(P._find(a), z): value
self._add_solution({(P.find(a), z): value
for (a, z), value in solution.items()})

def show(self, variables=True):
Expand Down Expand Up @@ -2813,7 +2813,7 @@ def add_alpha_beta_constraints(self):
Z_dict = {z: i for i, z in enumerate(Z)}

for a in self._bijectionist._A:
p = self._bijectionist._P._find(a)
p = self._bijectionist._P.find(a)
for z in self._bijectionist._possible_block_values[p]:
w_index = W_dict[self._bijectionist._alpha(a)]
z_index = Z_dict[z]
Expand Down Expand Up @@ -2867,7 +2867,7 @@ def add_distribution_constraints(self):
tA_sum = [zero] * len(Z_dict)
tZ_sum = [zero] * len(Z_dict)
for a in tA:
p = self._bijectionist._P._find(a)
p = self._bijectionist._P.find(a)
for z in self._bijectionist._possible_block_values[p]:
tA_sum[Z_dict[z]] += self._x[p, z]
for z in tZ:
Expand Down Expand Up @@ -2940,8 +2940,8 @@ def add_intertwining_relation_constraints(self):
continue
a = pi_rho.pi(*a_tuple)
if a in A:
p_tuple = tuple(P._find(a) for a in a_tuple)
p = P._find(a)
p_tuple = tuple(P.find(a) for a in a_tuple)
p = P.find(a)
if (p_tuple, p) not in pi_blocks:
pi_blocks.add((p_tuple, p))
for z_tuple in itertools.product(*[tZ[p] for p in p_tuple]):
Expand Down Expand Up @@ -3008,7 +3008,7 @@ def add_quadratic_relation_constraints(self):
z0 = phi(p)
assert all(phi(a) == z0 for a in block), "phi must be constant on the block %s" % block
for z in self._bijectionist._possible_block_values[p]:
p0 = P._find(psi(z))
p0 = P.find(psi(z))
if z0 in self._bijectionist._possible_block_values[p0]:
c = self._x[p, z] - self._x[p0, z0]
if c.is_zero():
Expand Down Expand Up @@ -3046,7 +3046,7 @@ def add_homomesic_constraints(self):
tZ = self._bijectionist._possible_block_values

def sum_q(q):
return sum(sum(z * self._x[P._find(a), z] for z in tZ[P._find(a)])
return sum(sum(z * self._x[P.find(a), z] for z in tZ[P.find(a)])
for a in q)
q0 = Q[0]
v0 = sum_q(q0)
Expand Down
2 changes: 1 addition & 1 deletion src/sage/combinat/designs/designs_pyx.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def is_group_divisible_design(groups,blocks,v,G=None,K=None,lambd=1,verbose=Fals
for i in range(n):
for j in range(i + 1, n):
if matrix[i * n + j] == 0:
groups._union(i, j)
groups.union(i, j)
groups = list(groups.root_to_elements_dict().values())

# Group sizes are element of G
Expand Down
2 changes: 1 addition & 1 deletion src/sage/combinat/designs/incidence_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ def is_connected(self) -> bool:
for B in self._blocks:
x = B[0]
for i in range(1, len(B)):
D._union(x, B[i])
D.union(x, B[i])
return D.number_of_subsets() == 1

def is_simple(self) -> bool:
Expand Down
22 changes: 11 additions & 11 deletions src/sage/combinat/posets/hasse_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -3230,36 +3230,36 @@ def fill_to_interval(S):
if part: # Skip empty parts
c = part[0]
for e in fill_to_interval(part):
cong._union(e, c)
cong.union(e, c)
t = cong.number_of_subsets()

# Following is needed for cases like
# posets.BooleanLattice(3).congruence([(0,1), (0,2), (0,4)])
for c in list(cong):
r = c[0]
for v in fill_to_interval(c):
cong._union(r, v)
cong.union(r, v)

todo = {cong._find(e) for part in parts for e in part}
todo = {cong.find(e) for part in parts for e in part}

while todo:

# First check if we should stop now.
for a, b in stop_pairs:
if cong._find(a) == cong._find(b):
if cong.find(a) == cong.find(b):
return None

# We take one block and try to find as big interval
# as possible to unify as a new block by the quadrilateral
# argument.
block = sorted(cong.root_to_elements_dict()[cong._find(todo.pop())])
block = sorted(cong.root_to_elements_dict()[cong.find(todo.pop())])

b = block[-1]
for a in block: # Quadrilateral up
for c in self.neighbor_out_iterator(a):
if c not in block:
d = jn[c, b]
if cong._find(d) != cong._find(c):
if cong.find(d) != cong.find(c):
break
else:
continue
Expand All @@ -3271,7 +3271,7 @@ def fill_to_interval(S):
for d in self.neighbor_in_iterator(b):
if d not in block:
c = mt[d, a]
if cong._find(c) != cong._find(d):
if cong.find(c) != cong.find(d):
break
else:
continue
Expand All @@ -3289,10 +3289,10 @@ def fill_to_interval(S):
# recursive process. In particular it may also combine to
# [a, b] block we just used.
while c is not None:
newblock = cong._find(c)
newblock = cong.find(c)
for i in self.interval(c, d):
cong._union(newblock, i)
C = cong.root_to_elements_dict()[cong._find(newblock)]
cong.union(newblock, i)
C = cong.root_to_elements_dict()[cong.find(newblock)]
mins = [i for i in C if all(i_ not in C for i_ in self.neighbor_in_iterator(i))]
maxs = [i for i in C if all(i_ not in C for i_ in self.neighbor_out_iterator(i))]
c = None # To stop loop, if this is not changed below.
Expand All @@ -3305,7 +3305,7 @@ def fill_to_interval(S):
d = jn[d, m]

# This removes duplicates from todo.
todo = {cong._find(x) for x in todo}
todo = {cong.find(x) for x in todo}

return cong

Expand Down
2 changes: 1 addition & 1 deletion src/sage/combinat/set_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2356,7 +2356,7 @@ def from_arcs(self, arcs, n):
"""
P = DisjointSet(range(1, n + 1))
for i, j in arcs:
P._union(i, j)
P.union(i, j)
return self.element_class(self, P)

def from_rook_placement_gamma(self, rooks, n):
Expand Down
10 changes: 5 additions & 5 deletions src/sage/combinat/words/finite_word.py
Original file line number Diff line number Diff line change
Expand Up @@ -5201,16 +5201,16 @@ def overlap_partition(self, other, delay=0, p=None, involution=None):
S = zip(islice(self, int(delay), None), other)
if involution is None:
for a, b in S:
p._union(a, b)
p.union(a, b)
elif isinstance(involution, WordMorphism):
for a, b in S:
p._union(a, b)
p.union(a, b)
# take the first letter of the word
p._union(involution(a)[0], involution(b)[0])
p.union(involution(a)[0], involution(b)[0])
elif callable(involution):
for a, b in S:
p._union(a, b)
p._union(involution(a), involution(b))
p.union(a, b)
p.union(involution(a), involution(b))
else:
raise TypeError("involution (=%s) must be callable" % involution)
return p
Expand Down
2 changes: 1 addition & 1 deletion src/sage/graphs/connectivity.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2540,7 +2540,7 @@ def spqr_tree(G, algorithm="Hopcroft_Tarjan", solver=None, verbose=0,
if cocycles_count[fe] == 2 and len(virtual_edge_to_cycles[fe]) == 2:
# This virtual edge is only between 2 cycles
C1, C2 = virtual_edge_to_cycles[fe]
DS._union(C1, C2)
DS.union(C1, C2)
cycles_list[C1].delete_edge(fe)
cycles_list[C2].delete_edge(fe)
cocycles_count[fe] -= 2
Expand Down
4 changes: 2 additions & 2 deletions src/sage/graphs/generators/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,12 +576,12 @@ def RandomBlockGraph(m, k, kmax=None, incidence_structure=False, seed=None):
# structure to keep a unique identifier per merged vertices
DS = DisjointSet([i for u in B for i in B[u]])
for u, v in T.edges(sort=True, labels=0):
DS._union(choice(B[u]), choice(B[v]))
DS.union(choice(B[u]), choice(B[v]))

# We relabel vertices in the range [0, m*(k-1)] and build the incidence
# structure
new_label = {root: i for i, root in enumerate(DS.root_to_elements_dict())}
IS = [[new_label[DS._find(v)] for v in B[u]] for u in B]
IS = [[new_label[DS.find(v)] for v in B[u]] for u in B]

if incidence_structure:
return IS
Expand Down
8 changes: 4 additions & 4 deletions src/sage/graphs/generic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12657,11 +12657,11 @@ def contract_edges(self, edges):
DS = DisjointSet(self.vertex_iterator())

for u, v, label in edge_list:
DS._union(u, v)
DS.union(u, v)

self.delete_edges(edge_list)
edges_incident = []
vertices = [v for v in vertices if v != DS._find(v)]
vertices = [v for v in vertices if v != DS.find(v)]
if self.is_directed():
for v in vertices:
out_edges = self.edge_boundary([v])
Expand All @@ -12674,8 +12674,8 @@ def contract_edges(self, edges):
self.delete_vertex(v)

for (u, v, label) in edges_incident:
root_u = DS._find(u)
root_v = DS._find(v)
root_u = DS.find(u)
root_v = DS.find(v)
if root_v != root_u or self.allows_loops():
self.add_edge(root_u, root_v, label)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def gamma_classes(graph):
e = frozenset([v1, v])
for vi in component[1:]:
ei = frozenset([vi, v])
pieces._union(e, ei)
pieces.union(e, ei)
return {frozenset(chain.from_iterable(loe)): loe for loe in pieces}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def is_valid_tree_decomposition(G, T):
for Xi in X:
for Xj in T.neighbor_iterator(Xi):
if Xj in X:
D._union(Xi, Xj)
D.union(Xi, Xj)
if D.number_of_subsets() > 1:
return False

Expand Down
8 changes: 4 additions & 4 deletions src/sage/graphs/partial_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,9 @@ def is_partial_cube(G, certificate=False):
if diff not in neighbors:
return fail
neighbor = neighbors[diff]
unionfind._union(contracted.edge_label(v, w),
unionfind.union(contracted.edge_label(v, w),
contracted.edge_label(root, neighbor))
unionfind._union(contracted.edge_label(w, v),
unionfind.union(contracted.edge_label(w, v),
contracted.edge_label(neighbor, root))
labeled.add_edge(v, w)

Expand All @@ -356,13 +356,13 @@ def is_partial_cube(G, certificate=False):
if vi == wi:
return fail
if newgraph.has_edge(vi, wi):
unionfind._union(newgraph.edge_label(vi, wi), t)
unionfind.union(newgraph.edge_label(vi, wi), t)
else:
newgraph.add_edge(vi, wi, t)
contracted = newgraph

# Make a digraph with edges labeled by the equivalence classes in unionfind
g = DiGraph({v: {w: unionfind._find((v, w)) for w in G[v]} for v in G})
g = DiGraph({v: {w: unionfind.find((v, w)) for w in G[v]} for v in G})

# Associates to a vertex the token that acts on it, and check that
# no two edges on a single vertex have the same label
Expand Down
Loading

0 comments on commit fad60ad

Please sign in to comment.