RF softmax use_mask option
albertz committed Apr 18, 2023
1 parent a6d5ef6 commit 2ae0c92
Showing 4 changed files with 25 additions and 14 deletions.
6 changes: 4 additions & 2 deletions returnn/frontend/_backend.py
@@ -432,19 +432,21 @@ def activation_raw(raw_tensor: T, func: str) -> T:
         raise NotImplementedError
 
     @staticmethod
-    def softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
+    def softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
         """
         :param tensor:
         :param axis:
+        :param use_mask:
         :return: softmax over axis
         """
         raise NotImplementedError
 
     @staticmethod
-    def log_softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
+    def log_softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
         """
         :param tensor:
         :param axis:
+        :param use_mask:
         :return: log_softmax over axis
         """
         raise NotImplementedError
8 changes: 4 additions & 4 deletions returnn/frontend/math_.py
@@ -434,13 +434,13 @@ def silu(a: Tensor) -> Tensor:
 swish = silu  # alias
 
 
-def softmax(a: Tensor, *, axis: Dim) -> Tensor:
+def softmax(a: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
     """softmax"""
     # noinspection PyProtectedMember
-    return a._raw_backend.softmax(a, axis=axis)
+    return a._raw_backend.softmax(a, axis=axis, use_mask=use_mask)
 
 
-def log_softmax(a: Tensor, *, axis: Dim) -> Tensor:
+def log_softmax(a: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
     """log_softmax"""
     # noinspection PyProtectedMember
-    return a._raw_backend.log_softmax(a, axis=axis)
+    return a._raw_backend.log_softmax(a, axis=axis, use_mask=use_mask)
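
With this change, frontend callers can opt out of sequence masking per call. A minimal usage sketch (the tensor and dim names here are illustrative, only rf.softmax / rf.log_softmax and the new use_mask flag come from this commit):

    import returnn.frontend as rf

    # assume `logits` is an rf.Tensor with a time/spatial axis `time_dim` (hypothetical names)
    probs = rf.softmax(logits, axis=time_dim)                        # default: masked over padded frames
    probs_unmasked = rf.softmax(logits, axis=time_dim, use_mask=False)  # skip the sequence mask
    log_probs = rf.log_softmax(logits, axis=time_dim, use_mask=False)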
15 changes: 11 additions & 4 deletions returnn/tf/frontend_layers/_backend.py
@@ -292,15 +292,22 @@ def activation_raw(raw_tensor: Layer, func: str) -> Layer:
         ).raw_tensor
 
     @staticmethod
-    def softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
+    def softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
         """softmax"""
-        return rfl.make_layer({"class": "softmax_over_spatial", "axis": axis, "from": tensor}, name="softmax")
+        args = {}
+        if not use_mask:
+            args["use_time_mask"] = False
+        return rfl.make_layer({"class": "softmax_over_spatial", "axis": axis, "from": tensor, **args}, name="softmax")
 
     @staticmethod
-    def log_softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
+    def log_softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
         """log softmax"""
+        args = {}
+        if not use_mask:
+            args["use_time_mask"] = False
         return rfl.make_layer(
-            {"class": "softmax_over_spatial", "axis": axis, "from": tensor, "log_space": True}, name="log_softmax"
+            {"class": "softmax_over_spatial", "axis": axis, "from": tensor, "log_space": True, **args},
+            name="log_softmax",
         )
 
     @staticmethod
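
In the TF layer-dict backend, use_mask=False is translated into the existing use_time_mask option of the generated softmax_over_spatial layer. A sketch of the layer dict that a call like rf.softmax(x, axis=time_dim, use_mask=False) would produce, based on the code above (names are illustrative):

    {
        "class": "softmax_over_spatial",
        "axis": time_dim,        # the Dim passed as `axis`
        "from": x,               # the input tensor
        "use_time_mask": False,  # only added when use_mask=False
    }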
10 changes: 6 additions & 4 deletions returnn/torch/frontend/_backend.py
@@ -367,14 +367,15 @@ def activation_raw(raw_tensor: torch.Tensor, func: str) -> torch.Tensor:
         return f(raw_tensor)
 
     @staticmethod
-    def softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
+    def softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
         """
         :param tensor:
         :param axis:
+        :param use_mask:
         :return: softmax over axis
         """
         out = tensor.copy_template("softmax")
-        if axis.need_masking():
+        if use_mask and axis.need_masking():
             tensor = tensor.copy()
             mask = tensor.get_sequence_mask_broadcast(axis=axis)
             inf_value = get_global_inf_value()
@@ -383,14 +384,15 @@ def softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
         return out
 
     @staticmethod
-    def log_softmax(tensor: Tensor, *, axis: Dim) -> Tensor:
+    def log_softmax(tensor: Tensor, *, axis: Dim, use_mask: bool = True) -> Tensor:
         """
         :param tensor:
         :param axis:
+        :param use_mask:
         :return: log_softmax over axis
         """
         out = tensor.copy_template("log_softmax")
-        if axis.need_masking():
+        if use_mask and axis.need_masking():
             tensor = tensor.copy()
             mask = tensor.get_sequence_mask_broadcast(axis=axis)
             inf_value = get_global_inf_value()
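
In the PyTorch backend, the sequence mask takes effect by pushing padded positions towards a large negative value before the softmax (the lines that apply the mask are collapsed in this diff). A rough standalone sketch of that idea in plain PyTorch, under that assumption and not the actual backend code:

    import torch

    def masked_softmax(logits: torch.Tensor, mask: torch.Tensor, dim: int, use_mask: bool = True) -> torch.Tensor:
        """Softmax over `dim`; if use_mask, positions where `mask` is False get ~zero probability."""
        if use_mask:
            inf_value = 1e30  # stand-in for get_global_inf_value()
            logits = torch.where(mask, logits, torch.full_like(logits, -inf_value))
        return torch.softmax(logits, dim=dim)

With use_mask=False the masking branch is skipped entirely, which mirrors the new `if use_mask and axis.need_masking():` check above.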
