From b9d038835a6ac9bd5a37a88197bdda5502f49746 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 5 Mar 2024 13:07:36 +0800 Subject: [PATCH 1/7] update median add min mode --- python/paddle/tensor/stat.py | 67 ++++++++++++-- test/legacy_test/test_median.py | 152 +++++++++++++++++++++++++++++++- 2 files changed, 212 insertions(+), 7 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 697859fd82adda..1aef672c49cc92 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -352,7 +352,7 @@ def nanmedian(x, axis=None, keepdim=False, name=None): return out -def median(x, axis=None, keepdim=False, name=None): +def median(x, axis=None, keepdim=False, mode='avg', name=None): """ Compute the median along the specified axis. @@ -367,11 +367,18 @@ def median(x, axis=None, keepdim=False, name=None): the output Tensor is the same as ``x`` except in the reduced dimensions(it is of size 1 in this case). Otherwise, the shape of the output Tensor is squeezed in ``axis`` . Default is False. + mode (str, optional): Whether to use mean or min operation to calculate + the median value when the input tensor has an even number of elements + in the dimension ``axis``. Support 'avg' and 'min'. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor, results of median along ``axis`` of ``x``. If data type of ``x`` is float64, data type of results will be float64, otherwise data type will be float32. + ((Tensor, Tensor), optional), results of median along ``axis`` of ``x``. If ``mode`` is + 'min' and axis is not None, the result will be a tuple containing a tensor of median value and a tensor + of its indices. The data type of the indices will be int64. Otherwise the result will be the tensor of + median value. If data type of ``x`` is float64, data type of median value will be float64, otherwise + data type will be float32. Examples: .. code-block:: python @@ -405,6 +412,18 @@ def median(x, axis=None, keepdim=False, name=None): Tensor(shape=[1, 4], dtype=float32, place=Place(cpu), stop_gradient=True, [[4., 5., 6., 7.]]) + >>> y5 = paddle.median(x, mode='min') + >>> print(y5) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, + 5.5) + + >>> median_value, median_indices = paddle.median(x, axis=1, mode='min') + >>> print(median_value) + Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, + [1., 5., 9.]) + >>> print(median_indices) + Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, + [1, 1, 1]) """ if not isinstance(x, (Variable, paddle.pir.Value)): raise TypeError("In median, the input x should be a Tensor.") @@ -423,6 +442,9 @@ def median(x, axis=None, keepdim=False, name=None): ], 'when input 0-D, axis can only be [-1, 0] or default None' is_flatten = True + if mode not in ('avg', 'min'): + raise ValueError(f"Mode {mode} is not supported. Must be avg or min.") + inp_axis = axis if axis is None: is_flatten = True @@ -446,10 +468,27 @@ def median(x, axis=None, keepdim=False, name=None): else 'float32' ) if sz & 1 == 0: - out_tensor = paddle.slice( - tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth] - ) + paddle.slice(tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1]) - out_tensor = paddle.cast(out_tensor, dtype=dtype) / 2 + if mode == 'avg': + out_tensor = paddle.slice( + tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth] + ) + paddle.slice( + tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1] + ) + out_tensor = paddle.cast(out_tensor, dtype=dtype) / 2 + else: # mode == 'min' + out_tensor = paddle.cast( + paddle.slice( + tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth] + ), + dtype=dtype, + ) + if inp_axis is not None: + out_idx = paddle.cast( + paddle.slice( + idx, axes=[axis], starts=[kth - 1], ends=[kth] + ), + dtype="int64", + ) else: out_tensor = paddle.cast( paddle.slice( @@ -457,6 +496,11 @@ def median(x, axis=None, keepdim=False, name=None): ), dtype=dtype, ) + if inp_axis is not None: + out_idx = paddle.cast( + paddle.slice(idx, axes=[axis], starts=[kth], ends=[kth + 1]), + dtype="int64", + ) out_tensor = out_tensor + paddle.sum( paddle.cast(paddle.isnan(x), dtype=dtype) * x, axis=axis, keepdim=True ) @@ -468,6 +512,17 @@ def median(x, axis=None, keepdim=False, name=None): else: if not keepdim: out_tensor = out_tensor.squeeze(axis) + + if mode == 'min' and inp_axis is not None: + if is_flatten: + if keepdim: + out_idx = out_idx.reshape([1] * dims) + else: + out_idx = out_idx.reshape([]) + else: + if not keepdim: + out_idx = out_idx.squeeze(axis) + return out_tensor, out_idx return out_tensor diff --git a/test/legacy_test/test_median.py b/test/legacy_test/test_median.py index 31750afe69fc56..485838cde2a6f6 100644 --- a/test/legacy_test/test_median.py +++ b/test/legacy_test/test_median.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import unittest import numpy as np @@ -22,7 +23,84 @@ DELTA = 1e-6 -class TestMedian(unittest.TestCase): +def np_medain_min(data, keepdims=False): + shape = data.shape + data_flat = data.flatten() + data_cnt = len(data_flat) + + data_flat[np.isnan(data_flat)] = np.inf + data_sort = np.sort(data_flat) + data_sort[np.isinf(data_sort)] = np.nan + + if data_cnt % 2: + is_odd = False + else: + is_odd = True + + i = int(data_cnt / 2) + if is_odd: + np_res = min(data_sort[i - 1], data_sort[i]) + else: + np_res = data_sort[i] + if keepdims: + new_shape = [1] * len(shape) + return np_res.reshape(new_shape) + else: + return np_res + + +def np_medain_min_axis(data, axis=None, keepdims=False): + data = copy.deepcopy(data) + if axis is None: + return np_medain_min(data, keepdims) + + axis = axis + len(data.shape) if axis < 0 else axis + trans_shape = [] + reshape = [] + for i in range(len(data.shape)): + if i != axis: + trans_shape.append(i) + reshape.append(data.shape[i]) + trans_shape.append(axis) + last_shape = data.shape[axis] + reshape.append(last_shape) + + data_flat = np.transpose(data, trans_shape) + + data_flat = np.reshape(data_flat, (-1, reshape[-1])) + + data_cnt = np.full( + shape=data_flat.shape[:-1], fill_value=data_flat.shape[-1] + ) + + data_flat[np.isnan(data_flat)] = np.inf + data_sort = np.sort(data_flat, axis=-1) + data_sort[np.isinf(data_sort)] = np.nan + + is_odd = data_cnt % 2 + + np_res = np.zeros(len(is_odd), dtype=data.dtype) + + for j in range(len(is_odd)): + if data_cnt[j] == 0: + np_res[j] = np.nan + continue + + i = int(data_cnt[j] / 2) + if is_odd[j]: + np_res[j] = data_sort[j, i] + else: + np_res[j] = min(data_sort[j, i - 1], data_sort[j, i]) + + if keepdims: + shape = list(data.shape) + shape[axis] = 1 + return np.reshape(np_res, shape) + else: + return np.reshape(np_res, reshape[:-1]) + + +class TestMedianAvg(unittest.TestCase): def check_numpy_res(self, np1, np2): self.assertEqual(np1.shape, np2.shape) mismatch = np.sum((np1 - np2) * (np1 - np2)) @@ -83,8 +161,80 @@ def test_median_exception(self): x = paddle.arange(12).reshape([3, 4]) self.assertRaises(ValueError, paddle.median, x, 1.0) self.assertRaises(ValueError, paddle.median, x, 2) + self.assertRaises(ValueError, paddle.median, x, 2, False, 'max') self.assertRaises(ValueError, paddle.median, paddle.to_tensor([])) +class TestMedianMin(unittest.TestCase): + def check_numpy_res(self, np1, np2): + self.assertEqual(np1.shape, np2.shape) + mismatch = np.sum((np1 - np2) * (np1 - np2)) + self.assertAlmostEqual(mismatch, 0, DELTA) + + def static_single_test_median(self, lis_test): + paddle.enable_static() + x, axis, keepdims = lis_test + res_np = np_medain_min_axis(x, axis=axis, keepdims=keepdims) + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + exe = paddle.static.Executor() + with paddle.static.program_guard(main_program, startup_program): + x_in = paddle.static.data(shape=x.shape, dtype=x.dtype, name='x') + y = paddle.median(x_in, axis, keepdims, mode='min') + [res_pd, _] = exe.run(feed={'x': x}, fetch_list=[y]) + self.check_numpy_res(res_pd, res_np) + paddle.disable_static() + + def dygraph_single_test_median(self, lis_test): + x, axis, keepdims = lis_test + res_np = np_medain_min_axis(x, axis=axis, keepdims=keepdims) + res_pd, _ = paddle.median( + paddle.to_tensor(x), axis, keepdims, mode='min' + ) + self.check_numpy_res(res_pd.numpy(False), res_np) + + @test_with_pir_api + def test_median_static(self): + h = 3 + w = 4 + l = 2 + x = np.arange(h * w * l).reshape([h, w, l]).astype("float32") + lis_tests = [ + [x, axis, keepdims] + for axis in [-1, 0, 1, 2] + for keepdims in [False, True] + ] + for lis_test in lis_tests: + self.static_single_test_median(lis_test) + + def test_median_dygraph(self): + paddle.disable_static() + h = 3 + w = 4 + l = 2 + x = np.arange(h * w * l).reshape([h, w, l]).astype("float32") + lis_tests = [ + [x, axis, keepdims] + for axis in [-1, 0, 1, 2] + for keepdims in [False, True] + ] + for lis_test in lis_tests: + self.dygraph_single_test_median(lis_test) + + def test_index_even_case(self): + paddle.disable_static() + x = paddle.arange(2 * 100).reshape((2, 100)).astype(paddle.float32) + out, index = paddle.median(x, axis=1, mode='min') + np.testing.assert_allclose(out.numpy(), [49.0, 149.0]) + np.testing.assert_equal(index.numpy(), [49, 49]) + + def test_index_odd_case(self): + paddle.disable_static() + x = paddle.arange(30).reshape((3, 10)).astype(paddle.float32) + out, index = paddle.median(x, axis=1, mode='min') + np.testing.assert_allclose(out.numpy(), [4.0, 14.0, 24.0]) + np.testing.assert_equal(index.numpy(), [4, 4, 4]) + + if __name__ == '__main__': unittest.main() From d9c372391abd79379bbd3c5302f2687969b38ab0 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 5 Mar 2024 16:42:23 +0800 Subject: [PATCH 2/7] update --- python/paddle/tensor/stat.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 1aef672c49cc92..df32f84575b01b 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -368,16 +368,16 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): dimensions(it is of size 1 in this case). Otherwise, the shape of the output Tensor is squeezed in ``axis`` . Default is False. mode (str, optional): Whether to use mean or min operation to calculate - the median value when the input tensor has an even number of elements + the median values when the input tensor has an even number of elements in the dimension ``axis``. Support 'avg' and 'min'. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: ((Tensor, Tensor), optional), results of median along ``axis`` of ``x``. If ``mode`` is - 'min' and axis is not None, the result will be a tuple containing a tensor of median value and a tensor + 'min' and axis is not None, the result will be a tuple containing a tensor of median values and a tensor of its indices. The data type of the indices will be int64. Otherwise the result will be the tensor of - median value. If data type of ``x`` is float64, data type of median value will be float64, otherwise + median values. If data type of ``x`` is float64, data type of median values will be float64, otherwise data type will be float32. Examples: @@ -514,14 +514,8 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): out_tensor = out_tensor.squeeze(axis) if mode == 'min' and inp_axis is not None: - if is_flatten: - if keepdim: - out_idx = out_idx.reshape([1] * dims) - else: - out_idx = out_idx.reshape([]) - else: - if not keepdim: - out_idx = out_idx.squeeze(axis) + if not keepdim: + out_idx = out_idx.squeeze(axis) return out_tensor, out_idx return out_tensor From 30c8d2a00fceca60a7cbeff4cd9d4aacda8ddac4 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 5 Mar 2024 22:32:44 +0800 Subject: [PATCH 3/7] update docs --- python/paddle/tensor/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index df32f84575b01b..3de5eb7f9d80e1 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -415,7 +415,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): >>> y5 = paddle.median(x, mode='min') >>> print(y5) Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, - 5.5) + 5.) >>> median_value, median_indices = paddle.median(x, axis=1, mode='min') >>> print(median_value) From c98d3d17a20b0661c3198f2bdddd38d622b31eaf Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Thu, 7 Mar 2024 10:50:18 +0800 Subject: [PATCH 4/7] update code --- python/paddle/tensor/stat.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 3de5eb7f9d80e1..d5bb13fab04519 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -369,16 +369,19 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): the output Tensor is squeezed in ``axis`` . Default is False. mode (str, optional): Whether to use mean or min operation to calculate the median values when the input tensor has an even number of elements - in the dimension ``axis``. Support 'avg' and 'min'. + in the dimension ``axis``. Support 'avg' and 'min'. Default is 'avg'. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - ((Tensor, Tensor), optional), results of median along ``axis`` of ``x``. If ``mode`` is - 'min' and axis is not None, the result will be a tuple containing a tensor of median values and a tensor - of its indices. The data type of the indices will be int64. Otherwise the result will be the tensor of - median values. If data type of ``x`` is float64, data type of median values will be float64, otherwise - data type will be float32. + ((Tensor, Tensor), optional) + If ``mode`` == 'avg', the result will be the tensor of median values; + If ``mode`` == 'min' and ``axis`` is None, the result will be the tensor of median values; + If ``mode`` == 'min' and ``axis`` is not None, the result will be a tuple of two tensors + containing median values and their indices. + + If data type of ``x`` is float64, data type of median values will be float64, otherwise + data type of median values will be float32. The data type of indices will be int64. Examples: .. code-block:: python @@ -444,7 +447,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): if mode not in ('avg', 'min'): raise ValueError(f"Mode {mode} is not supported. Must be avg or min.") - inp_axis = axis + need_idx = axis is not None if axis is None: is_flatten = True @@ -482,12 +485,9 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): ), dtype=dtype, ) - if inp_axis is not None: - out_idx = paddle.cast( - paddle.slice( - idx, axes=[axis], starts=[kth - 1], ends=[kth] - ), - dtype="int64", + if need_idx: + out_idx = paddle.slice( + idx, axes=[axis], starts=[kth - 1], ends=[kth] ) else: out_tensor = paddle.cast( @@ -496,11 +496,11 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): ), dtype=dtype, ) - if inp_axis is not None: - out_idx = paddle.cast( - paddle.slice(idx, axes=[axis], starts=[kth], ends=[kth + 1]), - dtype="int64", + if need_idx: + out_idx = paddle.slice( + idx, axes=[axis], starts=[kth], ends=[kth + 1] ) + out_tensor = out_tensor + paddle.sum( paddle.cast(paddle.isnan(x), dtype=dtype) * x, axis=axis, keepdim=True ) @@ -513,7 +513,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): if not keepdim: out_tensor = out_tensor.squeeze(axis) - if mode == 'min' and inp_axis is not None: + if mode == 'min' and need_idx: if not keepdim: out_idx = out_idx.squeeze(axis) return out_tensor, out_idx From 3851ade1015d2fc2745e92dfe7679cf829b35d5f Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Thu, 7 Mar 2024 13:13:59 +0800 Subject: [PATCH 5/7] update cast --- python/paddle/tensor/stat.py | 42 ++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index d5bb13fab04519..fc7ccd17add485 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -380,8 +380,10 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): If ``mode`` == 'min' and ``axis`` is not None, the result will be a tuple of two tensors containing median values and their indices. - If data type of ``x`` is float64, data type of median values will be float64, otherwise - data type of median values will be float32. The data type of indices will be int64. + When ``mode`` == 'avg', if data type of ``x`` is float64, data type of median values will be float64, + otherwise data type of median values will be float32. + When ``mode`` == 'min', the data type of median values will be the same as ``x``. The data type of + indices will be int64. Examples: .. code-block:: python @@ -417,13 +419,13 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): >>> y5 = paddle.median(x, mode='min') >>> print(y5) - Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, - 5.) + Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, + 5) >>> median_value, median_indices = paddle.median(x, axis=1, mode='min') >>> print(median_value) - Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, - [1., 5., 9.]) + Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, + [1, 5, 9]) >>> print(median_indices) Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [1, 1, 1]) @@ -470,36 +472,38 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): in [core.VarDesc.VarType.FP64, paddle.base.core.DataType.FLOAT64] else 'float32' ) - if sz & 1 == 0: - if mode == 'avg': + if mode == 'avg': + if sz & 1 == 0: out_tensor = paddle.slice( tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth] ) + paddle.slice( tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1] ) out_tensor = paddle.cast(out_tensor, dtype=dtype) / 2 - else: # mode == 'min' + else: out_tensor = paddle.cast( paddle.slice( - tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth] + tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1] ), dtype=dtype, ) + else: # mode == 'min' + if sz & 1 == 0: + out_tensor = paddle.slice( + tensor_topk, axes=[axis], starts=[kth - 1], ends=[kth] + ) if need_idx: out_idx = paddle.slice( idx, axes=[axis], starts=[kth - 1], ends=[kth] ) - else: - out_tensor = paddle.cast( - paddle.slice( + else: + out_tensor = paddle.slice( tensor_topk, axes=[axis], starts=[kth], ends=[kth + 1] - ), - dtype=dtype, - ) - if need_idx: - out_idx = paddle.slice( - idx, axes=[axis], starts=[kth], ends=[kth + 1] ) + if need_idx: + out_idx = paddle.slice( + idx, axes=[axis], starts=[kth], ends=[kth + 1] + ) out_tensor = out_tensor + paddle.sum( paddle.cast(paddle.isnan(x), dtype=dtype) * x, axis=axis, keepdim=True From dd3aa4814d3ae90bbe7cff37a9a3a48be02ecff1 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Fri, 8 Mar 2024 12:32:22 +0800 Subject: [PATCH 6/7] update test --- test/legacy_test/test_median.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/test/legacy_test/test_median.py b/test/legacy_test/test_median.py index 485838cde2a6f6..ee38ef57f79c9e 100644 --- a/test/legacy_test/test_median.py +++ b/test/legacy_test/test_median.py @@ -166,11 +166,6 @@ def test_median_exception(self): class TestMedianMin(unittest.TestCase): - def check_numpy_res(self, np1, np2): - self.assertEqual(np1.shape, np2.shape) - mismatch = np.sum((np1 - np2) * (np1 - np2)) - self.assertAlmostEqual(mismatch, 0, DELTA) - def static_single_test_median(self, lis_test): paddle.enable_static() x, axis, keepdims = lis_test @@ -182,7 +177,7 @@ def static_single_test_median(self, lis_test): x_in = paddle.static.data(shape=x.shape, dtype=x.dtype, name='x') y = paddle.median(x_in, axis, keepdims, mode='min') [res_pd, _] = exe.run(feed={'x': x}, fetch_list=[y]) - self.check_numpy_res(res_pd, res_np) + np.testing.assert_allclose(res_pd, res_np) paddle.disable_static() def dygraph_single_test_median(self, lis_test): @@ -191,7 +186,7 @@ def dygraph_single_test_median(self, lis_test): res_pd, _ = paddle.median( paddle.to_tensor(x), axis, keepdims, mode='min' ) - self.check_numpy_res(res_pd.numpy(False), res_np) + np.testing.assert_allclose(res_pd.numpy(False), res_np) @test_with_pir_api def test_median_static(self): From 1afb44a9f74e45f12e9e35cf17f8d924803e7666 Mon Sep 17 00:00:00 2001 From: zachary sun <70642955+sunzhongkai588@users.noreply.github.com> Date: Tue, 12 Mar 2024 16:47:36 +0800 Subject: [PATCH 7/7] Update python/paddle/tensor/stat.py --- python/paddle/tensor/stat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index fc7ccd17add485..dc5fa034c88547 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -374,7 +374,7 @@ def median(x, axis=None, keepdim=False, mode='avg', name=None): For more information, please refer to :ref:`api_guide_Name`. Returns: - ((Tensor, Tensor), optional) + Tensor or tuple of Tensor. If ``mode`` == 'avg', the result will be the tensor of median values; If ``mode`` == 'min' and ``axis`` is None, the result will be the tensor of median values; If ``mode`` == 'min' and ``axis`` is not None, the result will be a tuple of two tensors