-
Notifications
You must be signed in to change notification settings - Fork 0
/
mmd_code.py
47 lines (33 loc) · 1.45 KB
/
mmd_code.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
## Code from : https://www.onurtunali.com/ml/2019/03/08/maximum-mean-discrepancy-in-machine-learning.html
# Default compute device: the current CUDA device when one is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Default kernel parameters per kernel name (the values the original code used).
_DEFAULT_BANDWIDTHS = {
    "multiscale": (0.2, 0.5, 0.9, 1.3),
    "rbf": (10, 15, 20, 50),
}


def _pairwise_sq_dists(a, b):
    """Return the (n, m) matrix of squared Euclidean distances between rows of a and b."""
    cross = torch.mm(a, b.t())  # (n, m); torch.mm raises if a or b is not 2-D
    a_sq = (a * a).sum(dim=1, keepdim=True)  # (n, 1): ||a_i||^2
    b_sq = (b * b).sum(dim=1, keepdim=True).t()  # (1, m): ||b_j||^2
    # ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2 <a_i, b_j>.  Floating-point
    # cancellation can produce tiny negative values; true distances are >= 0.
    return (a_sq + b_sq - 2.0 * cross).clamp_min(0.0)


def _kernel_matrix(sq_dists, kernel, bandwidths):
    """Sum the named kernel over every bandwidth, elementwise on sq_dists."""
    # zeros_like keeps the device and dtype of the inputs (no global device).
    out = torch.zeros_like(sq_dists)
    for a in bandwidths:
        if kernel == "multiscale":
            out += a**2 / (a**2 + sq_dists)
        else:  # "rbf"
            out += torch.exp(-0.5 * sq_dists / a)
    return out


def MMD(x, y, kernel, bandwidths=None):
    """Empirical maximum mean discrepancy between two samples.

    Computes the biased (V-statistic) estimate

        mean_{i,i'} k(x_i, x_i') + mean_{j,j'} k(y_j, y_j') - 2 mean_{i,j} k(x_i, y_j)

    where every mean runs over all pairs, diagonal included. The lower the
    result, the more evidence that the two samples come from the same
    distribution.

    Args:
        x: first sample (distribution P), a 2-D tensor of shape (n, d).
        y: second sample (distribution Q), a 2-D tensor of shape (m, d).
            ``m`` may differ from ``n``. ``x`` and ``y`` must live on the same
            device and share a dtype.
        kernel: ``"multiscale"`` (sum over a of a^2 / (a^2 + ||u - v||^2)) or
            ``"rbf"`` (sum over a of exp(-||u - v||^2 / (2 a))).
        bandwidths: optional iterable of kernel parameters ``a``. Defaults to
            (0.2, 0.5, 0.9, 1.3) for "multiscale" and (10, 15, 20, 50) for
            "rbf".

    Returns:
        A 0-dim tensor on the inputs' device. Floating inputs keep their
        dtype; integer inputs are promoted to the default float dtype.

    Raises:
        ValueError: if ``kernel`` is not a supported kernel name.
    """
    if kernel not in _DEFAULT_BANDWIDTHS:
        raise ValueError(
            f"Unsupported kernel {kernel!r}; expected 'multiscale' or 'rbf'"
        )
    # Materialise once: the bandwidths are iterated three times below.
    bandwidths = tuple(_DEFAULT_BANDWIDTHS[kernel] if bandwidths is None else bandwidths)

    # In-place float accumulation into an integer buffer would fail, so
    # promote integer samples to the default float dtype.
    if not x.is_floating_point():
        x = x.to(torch.get_default_dtype())
    if not y.is_floating_point():
        y = y.to(torch.get_default_dtype())

    k_xx = _kernel_matrix(_pairwise_sq_dists(x, x), kernel, bandwidths)  # (n, n)
    k_yy = _kernel_matrix(_pairwise_sq_dists(y, y), kernel, bandwidths)  # (m, m)
    k_xy = _kernel_matrix(_pairwise_sq_dists(x, y), kernel, bandwidths)  # (n, m)

    # Average each term separately so samples of different sizes work. This
    # equals mean(XX + YY - 2 XY) when n == m.
    return k_xx.mean() + k_yy.mean() - 2.0 * k_xy.mean()