-
Notifications
You must be signed in to change notification settings - Fork 1
/
sol.java
77 lines (67 loc) · 1.89 KB
/
sol.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import java.lang.*;
import java.io.*;
import java.math.*;
import java.util.*;
public class Solution {
    /** Modulus used for all arithmetic in this problem. */
    private static final int MOD = 1000000007;

    /**
     * Computes {@code a^b mod MOD} by binary (square-and-multiply) exponentiation.
     *
     * @param a base; may be any long, including negative values (normalized into [0, MOD))
     * @param b exponent; values {@code <= 0} yield 1
     * @return {@code a^b mod MOD}, always in {@code [0, MOD)}
     */
    static long binpow(long a, long b) {
        // floorMod (unlike %) keeps a negative base in [0, MOD), so the result is never negative.
        long base = Math.floorMod(a, (long) MOD);
        long res = 1;
        while (b > 0) {
            if ((b & 1) == 1) {
                res = res * base % MOD;
            }
            base = base * base % MOD;
            b >>= 1;
        }
        return res;
    }

    /**
     * Reads {@code n k}, then {@code n-1} edges {@code u v c} (1-based endpoints), and prints
     * {@code n^k - sum over components of size^k (mod MOD)}. The components are those formed
     * by edges with {@code c == 0} only.
     */
    public static void main(String[] args) throws IOException {
        // StreamTokenizer over a buffered reader is much faster than Scanner for large inputs.
        StreamTokenizer in =
                new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
        int n = nextInt(in);
        int k = nextInt(in);
        UF uf = new UF(n);
        for (int i = 0; i < n - 1; i++) {
            int u = nextInt(in) - 1;
            int v = nextInt(in) - 1;
            int c = nextInt(in);
            if (c == 0) {
                uf.union(u, v);
            }
        }
        long ans = binpow(n, k);
        for (int i = 0; i < n; i++) {
            // Each root represents one c == 0 component; subtract size^k for it exactly once.
            if (uf.isp(i)) {
                ans = Math.floorMod(ans - binpow(uf.getSize(i), k), (long) MOD);
            }
        }
        System.out.println(ans);
    }

    /**
     * Reads the next integer token.
     *
     * @throws IOException if the input ends or the next token is not a number
     */
    private static int nextInt(StreamTokenizer in) throws IOException {
        if (in.nextToken() != StreamTokenizer.TT_NUMBER) {
            throw new IOException("expected an integer in input");
        }
        return (int) in.nval;
    }
}
/**
 * Disjoint-set (union-find) over the integers {@code 0..size-1}, using union by size and
 * path compression.
 */
class UF {
    /** {@code parent[x]} is x's parent in the forest; x is a root iff {@code parent[x] == x}. */
    private final int[] parent;
    /** {@code sz[r]} is the size of the set rooted at r; only meaningful when r is a root. */
    private final int[] sz;

    /** Creates {@code size} singleton sets {@code {0}, {1}, ..., {size-1}}. */
    public UF(int size) {
        parent = new int[size];
        sz = new int[size];
        for (int i = 0; i < size; i++) {
            parent[i] = i;
            sz[i] = 1;
        }
    }

    /**
     * Returns the root of x's set and compresses the path from x to it.
     * Union by size bounds the tree height by about log2(size), so the recursion stays shallow.
     */
    public int find(int x) {
        if (parent[x] != x) {
            parent[x] = find(parent[x]);
        }
        return parent[x];
    }

    /** Returns true if x and y are in the same set. */
    public boolean connected(int x, int y) {
        return find(x) == find(y);
    }

    /** Merges the sets containing x and y; does nothing if they are already joined. */
    public void union(int x, int y) {
        int px = find(x);
        int py = find(y);
        if (px == py) {
            return;
        }
        // Attach the smaller tree under the larger one to keep trees shallow.
        if (sz[px] > sz[py]) {
            parent[py] = px;
            sz[px] += sz[py];
        } else {
            parent[px] = py;
            sz[py] += sz[px];
        }
    }

    /** Returns the size of the set containing x (valid for any element, not only roots). */
    public int getSize(int x) {
        return sz[find(x)];
    }

    /** Returns true if x is currently the root (representative) of its set. */
    public boolean isp(int x) {
        return parent[x] == x;
    }
}