From 81f08fbcb0512907b6d3786ef4aebc1e1057fbc8 Mon Sep 17 00:00:00 2001 From: Jonas Burian Date: Mon, 14 Oct 2024 19:05:08 +0300 Subject: [PATCH] Add option to use MMD variance ratio loss --- DaNN/mmd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/DaNN/mmd.py b/DaNN/mmd.py index e93b5e3..081c66d 100644 --- a/DaNN/mmd.py +++ b/DaNN/mmd.py @@ -8,7 +8,9 @@ -def mmd_loss(x_src, x_tar,gamma=10 ^ 3): +def mmd_loss(x_src, x_tar, gamma=10 ^ 3, use_var=False): + if use_var: + return mix_rbf_mmd2_and_ratio(x_src, x_tar, [gamma]) return mix_rbf_mmd2(x_src, x_tar, [gamma]) @@ -78,7 +80,7 @@ def mix_rbf_mmd2(X, Y, sigma_list, biased=True): def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True): K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) # return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) - return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) + return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)[0] ################################################################################