Skip to content

Commit

Permalink
Add option to use MMD variance ratio loss
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasburian committed Oct 14, 2024
1 parent c0efdbc commit 81f08fb
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions DaNN/mmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@



def mmd_loss(x_src, x_tar,gamma=10 ^ 3):
def mmd_loss(x_src, x_tar, gamma=10 ^ 3, use_var=False):
if use_var:
return mix_rbf_mmd2_and_ratio(x_src, x_tar, [gamma])
return mix_rbf_mmd2(x_src, x_tar, [gamma])


Expand Down Expand Up @@ -78,7 +80,7 @@ def mix_rbf_mmd2(X, Y, sigma_list, biased=True):
def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True):
K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list)
# return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased)
return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)
return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)[0]


################################################################################
Expand Down

0 comments on commit 81f08fb

Please sign in to comment.