-
-
Notifications
You must be signed in to change notification settings - Fork 7.3k
/
dsu_path_compression.cpp
214 lines (203 loc) · 6.79 KB
/
dsu_path_compression.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
/**
* @file
* @brief [DSU (Disjoint
* sets)](https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
* @details
* It is a very powerful data structure that keeps track of different
* clusters (sets) of elements; these sets are disjoint (they do not have a
* common element). Use cases of disjoint sets: finding connected components in
* a graph, and Kruskal's algorithm for finding a minimum spanning tree.
* Operations that can be performed:
* 1) UnionSet(i,j): merges the sets containing elements i and j into one set
* 2) findSet(i): returns the representative of the set to which i belongs.
* 3) get_max(i),get_min(i) : returns the maximum and minimum element of the
* set to which i belongs
* Below is the class-based approach which uses the heuristic of path
* compression. Using path compression in findSet(i), we are able to get to the
* representative of i in nearly constant amortized time.
* @author [AayushVyasKIIT](https://github.com/AayushVyasKIIT)
* @see dsu_union_rank.cpp
*/
#include <algorithm> /// for std::max, std::min
#include <cassert> /// for assert
#include <cstdint>
#include <iostream> /// for IO operations
#include <numeric> /// for std::iota
#include <utility> /// for std::swap
#include <vector> /// for std::vector
using std::cout;
using std::endl;
using std::vector;
/**
 * @brief Disjoint-set union (DSU) over the elements `0 .. n-1`, using path
 * compression in `findSet` and union by depth (rank) in `UnionSet`.
 * @details Besides the parent links, every set root stores the size, the
 * largest element and the smallest element of its set. All element arguments
 * must be smaller than the `n` passed to the constructor.
 */
class dsu {
 private:
    std::vector<uint64_t> p;  ///< p[i] is the parent of i; a root has p[i] == i
    std::vector<uint64_t> depth;  ///< rank of each root (upper bound on height)
    std::vector<uint64_t> setSize;     ///< size of the set, valid at roots
    std::vector<uint64_t> maxElement;  ///< largest element, valid at roots
    std::vector<uint64_t> minElement;  ///< smallest element, valid at roots

 public:
    /**
     * @brief Creates `n` singleton sets `{0}, {1}, ..., {n-1}`.
     * @param n number of elements
     */
    explicit dsu(uint64_t n) : p(n), depth(n, 0), setSize(n, 1) {
        /// initially every element is its own parent ...
        std::iota(p.begin(), p.end(), uint64_t{0});
        /// ... and therefore also the maximum and minimum of its own set
        maxElement = p;
        minElement = p;
    }

    /**
     * @brief Finds the representative (root) of the set containing `i` and
     * compresses the path from `i` to that root.
     * @details Runs in near-constant amortized time. Implemented iteratively:
     * the first pass locates the root, the second pass re-points every node
     * on the walked path directly at it.
     * @param i element of some set, must be less than the number of elements
     * @returns representative of the set to which i belongs
     */
    uint64_t findSet(uint64_t i) {
        /// every public operation funnels through here, so this single check
        /// guards all of them against out-of-range elements in debug builds
        assert(i < p.size());
        uint64_t root = i;
        while (p[root] != root) {
            root = p[root];
        }
        /// path compression
        while (p[i] != root) {
            const uint64_t next = p[i];
            p[i] = root;
            i = next;
        }
        return root;
    }

    /**
     * @brief Merges the sets containing `i` and `j` into a single set with a
     * common representative. Does nothing if they already share a set.
     * @param i element of some set
     * @param j element of some set
     * @returns void
     */
    void UnionSet(uint64_t i, uint64_t j) {
        /// representatives of the two sets
        uint64_t x = findSet(i);
        uint64_t y = findSet(j);
        /// same representative: already in one set, nothing to merge
        if (x == y) {
            return;
        }
        /// make sure x is the root of the shallower tree
        if (depth[x] > depth[y]) {
            std::swap(x, y);
        }
        /// hang the shallower tree below the deeper one
        p[x] = y;
        /// equal depth: the merged tree is one level deeper
        if (depth[x] == depth[y]) {
            depth[y]++;
        }
        /// fold the statistics of x's set into the surviving root y
        setSize[y] += setSize[x];
        maxElement[y] = std::max(maxElement[x], maxElement[y]);
        minElement[y] = std::min(minElement[x], minElement[y]);
    }

    /**
     * @brief Checks whether `i` and `j` belong to the same set.
     * @param i element of some set
     * @param j element of some set
     * @returns `true` if element `i` and `j` ARE in the same set
     * @returns `false` if element `i` and `j` are NOT in same set
     */
    bool isSame(uint64_t i, uint64_t j) { return findSet(i) == findSet(j); }

    /**
     * @brief Collects the minimum, maximum and size of the set to which `i`
     * belongs.
     * @param i element of some set
     * @returns vector `{minimum, maximum, size}` of the set containing `i`
     */
    std::vector<uint64_t> get(uint64_t i) {
        return {get_min(i), get_max(i), size(i)};
    }

    /**
     * @brief Returns the size of the set to which `i` belongs.
     * @param i element of some set
     * @returns size of the set to which i belongs
     */
    uint64_t size(uint64_t i) { return setSize[findSet(i)]; }

    /**
     * @brief Returns the largest element of the set to which `i` belongs.
     * @param i element of some set
     * @returns maximum of the set to which i belongs
     */
    uint64_t get_max(uint64_t i) { return maxElement[findSet(i)]; }

    /**
     * @brief Returns the smallest element of the set to which `i` belongs.
     * @param i element of some set
     * @returns minimum of the set to which i belongs
     */
    uint64_t get_min(uint64_t i) { return minElement[findSet(i)]; }
};
/**
* @brief Self-test implementations, 1st test
* @returns void
*/
static void test1() {
// the minimum, maximum, and size of the set
uint64_t n = 10; ///< number of items
dsu d(n + 1); ///< object of class disjoint sets
// set 1
d.UnionSet(1, 2); // performs union operation on 1 and 2
d.UnionSet(1, 4); // performs union operation on 1 and 4
vector<uint64_t> ans = {1, 4, 3};
for (uint64_t i = 0; i < ans.size(); i++) {
assert(d.get(4).at(i) == ans[i]); // makes sure algorithm works fine
}
cout << "1st test passed!" << endl;
}
/**
* @brief Self-implementations, 2nd test
* @returns void
*/
static void test2() {
// the minimum, maximum, and size of the set
uint64_t n = 10; ///< number of items
dsu d(n + 1); ///< object of class disjoint sets
// set 1
d.UnionSet(3, 5);
d.UnionSet(5, 6);
d.UnionSet(5, 7);
vector<uint64_t> ans = {3, 7, 4};
for (uint64_t i = 0; i < ans.size(); i++) {
assert(d.get(3).at(i) == ans[i]); // makes sure algorithm works fine
}
cout << "2nd test passed!" << endl;
}
/**
 * @brief Main function: runs the self-tests. Each test builds its own `dsu`
 * instance, so no shared object is created here.
 * @returns 0 on exit
 */
int main() {
    test1();  // run 1st test case
    test2();  // run 2nd test case
    return 0;
}