-
Notifications
You must be signed in to change notification settings - Fork 46
/
_310.java
76 lines (69 loc) · 2.46 KB
/
_310.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import java.util.*;
/**
* LeetCode 310 - Minimum Height Trees
*
* It is easy to see that the root of an MHT has to be the middle point (or two middle points) of the longest path of the tree.
* Though multiple longest paths can appear in an unrooted tree, they must share the same middle point(s).
*
* Computing the longest path of an unrooted tree can be done, in O(n) time, by tree dp, or simply by 2 tree traversals (dfs or bfs).
* The following code does the latter.
*
* Randomly select a node x as the root, do a dfs/bfs to find the node y that has the longest distance from x.
* Then y must be one of the endpoints on some longest path.
* Let y be the new root, and do another dfs/bfs. Find the node z that has the longest distance from y.
*
* Now, the path from y to z is the longest one, and thus its middle point(s) is the answer.
*/
public class _310 {
    /** Node count of the graph most recently passed to {@link #findMinHeightTrees}. */
    int n;

    /** Adjacency lists: {@code e[i]} holds every neighbour of node {@code i}. */
    List<Integer>[] e;

    /**
     * Breadth-first traversal over {@link #e} starting at {@code start}.
     *
     * @param start node the traversal begins at
     * @param dist  output: edge count from {@code start} for every reached node
     * @param pre   output: BFS parent of every reached node; {@code -1} for {@code start}
     */
    private void bfs(int start, int[] dist, int[] pre) {
        boolean[] seen = new boolean[n];
        Queue<Integer> frontier = new ArrayDeque<>();

        seen[start] = true;
        dist[start] = 0;
        pre[start] = -1;
        frontier.add(start);

        while (!frontier.isEmpty()) {
            int current = frontier.poll();
            for (int next : e[current]) {
                if (seen[next]) {
                    continue;
                }
                // Mark on enqueue so a node can never enter the queue twice.
                seen[next] = true;
                pre[next] = current;
                dist[next] = dist[current] + 1;
                frontier.add(next);
            }
        }
    }

    /**
     * Returns the lowest index holding the maximum value of {@code dist}.
     * Ties keep the earliest index because the comparison is strict.
     */
    private static int indexOfFarthest(int[] dist) {
        int best = 0;
        for (int i = 1; i < dist.length; i++) {
            if (dist[i] > dist[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fills {@link #n} and {@link #e} from an undirected edge list. */
    private void buildGraph(int nodeCount, int[][] edges) {
        this.n = nodeCount;

        // Generic array creation is unavoidable here; every slot is filled
        // with an ArrayList<Integer> right below, so the cast-free assignment is safe.
        @SuppressWarnings("unchecked")
        List<Integer>[] adjacency = new List[nodeCount];
        for (int i = 0; i < nodeCount; i++) {
            adjacency[i] = new ArrayList<>();
        }
        e = adjacency;

        for (int[] edge : edges) {
            int a = edge[0];
            int b = edge[1];
            e[a].add(b);
            e[b].add(a);
        }
    }

    /**
     * Finds the root(s) that minimise the height of the tree described by {@code edges}.
     *
     * <p>Two breadth-first searches locate the end points of a longest path; the
     * middle node (odd node count on the path) or the two middle nodes (even
     * node count) of that path are returned.
     *
     * @param n     number of nodes, labelled {@code 0 .. n-1}
     * @param edges undirected edges, each given as a two-element array of node labels
     * @return one or two root labels; an empty list when {@code n <= 0}
     */
    public List<Integer> findMinHeightTrees(int n, int[][] edges) {
        if (n <= 0) {
            return new ArrayList<>();
        }
        buildGraph(n, edges);

        int[] distFromZero = new int[n];
        int[] distFromEnd = new int[n];
        int[] parent = new int[n];

        // First sweep: the node farthest from node 0 is one end of a longest path.
        bfs(0, distFromZero, parent);
        int pathEnd = indexOfFarthest(distFromZero);

        // Second sweep: the node farthest from that end is the other end;
        // parent[] now describes the path back to pathEnd.
        bfs(pathEnd, distFromEnd, parent);
        int otherEnd = indexOfFarthest(distFromEnd);

        List<Integer> path = new ArrayList<>();
        for (int node = otherEnd; node != -1; node = parent[node]) {
            path.add(node);
        }

        int mid = path.size() / 2;
        boolean singleCentre = path.size() % 2 == 1;
        if (singleCentre) {
            return Arrays.asList(path.get(mid));
        }
        return Arrays.asList(path.get(mid - 1), path.get(mid));
    }
}