diff --git a/returnn/frontend/_backend.py b/returnn/frontend/_backend.py
index 5cb29f651..15dc77350 100644
--- a/returnn/frontend/_backend.py
+++ b/returnn/frontend/_backend.py
@@ -432,19 +432,21 @@ def activation_raw(raw_tensor: T, func: str) -> T:
         raise NotImplementedError
 
     @staticmethod
-    def softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
+    def softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
         """
         :param tensor:
         :param axis:
+        :param use_mask:
         :return: softmax over axis
         """
         raise NotImplementedError
 
     @staticmethod
-    def log_softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
+    def log_softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
         """
         :param tensor:
         :param axis:
+        :param use_mask:
         :return: log_softmax over axis
         """
         raise NotImplementedError
diff --git a/returnn/frontend/math_.py b/returnn/frontend/math_.py
index 761aace94..d6139ae95 100644
--- a/returnn/frontend/math_.py
+++ b/returnn/frontend/math_.py
@@ -434,13 +434,13 @@ def silu(a: Tensor) -> Tensor:
 swish = silu  # alias
 
 
-def softmax(a: Tensor, *, axis: Dim) -> Tensor:
+def softmax(a: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
     """softmax"""
     # noinspection PyProtectedMember
-    return a._raw_backend.softmax(a, axis=axis)
+    return a._raw_backend.softmax(a, axis=axis, use_mask=use_mask)
 
 
-def log_softmax(a: Tensor, *, axis: Dim) -> Tensor:
+def log_softmax(a: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
     """log_softmax"""
     # noinspection PyProtectedMember
-    return a._raw_backend.log_softmax(a, axis=axis)
+    return a._raw_backend.log_softmax(a, axis=axis, use_mask=use_mask)
diff --git a/returnn/tf/frontend_layers/_backend.py b/returnn/tf/frontend_layers/_backend.py
index d9225e0d2..c07f9ae57 100644
--- a/returnn/tf/frontend_layers/_backend.py
+++ b/returnn/tf/frontend_layers/_backend.py
@@ -292,15 +292,22 @@ def activation_raw(raw_tensor: Layer, func: str) -> Layer:
         ).raw_tensor
 
     @staticmethod
-    def softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
+    def softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
         """softmax"""
-        return rfl.make_layer({"class": "softmax_over_spatial", "axis": axis, "from": tensor}, name="softmax")
+        args = {}
+        if not use_mask:
+            args["use_time_mask"] = False
+        return rfl.make_layer({"class": "softmax_over_spatial", "axis": axis, "from": tensor, **args}, name="softmax")
 
     @staticmethod
-    def log_softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
+    def log_softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
         """log softmax"""
+        args = {}
+        if not use_mask:
+            args["use_time_mask"] = False
         return rfl.make_layer(
-            {"class": "softmax_over_spatial", "axis": axis, "from": tensor, "log_space": True}, name="log_softmax"
+            {"class": "softmax_over_spatial", "axis": axis, "from": tensor, "log_space": True, **args},
+            name="log_softmax",
         )
 
     @staticmethod
diff --git a/returnn/torch/frontend/_backend.py b/returnn/torch/frontend/_backend.py
index c8dd4059b..ea6daf952 100644
--- a/returnn/torch/frontend/_backend.py
+++ b/returnn/torch/frontend/_backend.py
@@ -367,14 +367,15 @@ def activation_raw(raw_tensor: torch.Tensor, func: str) -> torch.Tensor:
         return f(raw_tensor)
 
     @staticmethod
-    def softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
+    def softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
         """
         :param tensor:
         :param axis:
+        :param use_mask:
         :return: softmax over axis
         """
         out = tensor.copy_template("softmax")
-        if axis.need_masking():
+        if use_mask and axis.need_masking():
             tensor = tensor.copy()
             mask = tensor.get_sequence_mask_broadcast(axis=axis)
             inf_value = get_global_inf_value()
@@ -383,14 +384,15 @@ def softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
         return out
 
     @staticmethod
-    def log_softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
+    def log_softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
         """
         :param tensor:
         :param axis:
+        :param use_mask:
         :return: log_softmax over axis
         """
         out = tensor.copy_template("log_softmax")
-        if axis.need_masking():
+        if use_mask and axis.need_masking():
             tensor = tensor.copy()
             mask = tensor.get_sequence_mask_broadcast(axis=axis)
             inf_value = get_global_inf_value()
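
For reference, a minimal usage sketch of the new use_mask flag through the frontend API (the tensor and dim names logits and spatial_dim are only illustrative, not from this diff):

    import returnn.frontend as rf

    # Default behavior (use_mask=True): frames beyond the sequence lengths of
    # spatial_dim are masked out before the softmax, as before this change.
    probs = rf.softmax(logits, axis=spatial_dim)

    # New: explicitly skip the sequence mask, e.g. when the axis has no padding
    # or the caller already applied its own masking.
    probs_unmasked = rf.softmax(logits, axis=spatial_dim, use_mask=False)
    log_probs = rf.log_softmax(logits, axis=spatial_dim, use_mask=False)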