From fb2f862dd7ae9635b797694eee8337d8ca6e97fb Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Tue, 30 Apr 2024 05:19:15 +0800 Subject: [PATCH] do not require `bk_flow_mask_fn` for all distributions --- src/pyjuice/layer/input_layer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pyjuice/layer/input_layer.py b/src/pyjuice/layer/input_layer.py index fd346aa7..9db29bf3 100644 --- a/src/pyjuice/layer/input_layer.py +++ b/src/pyjuice/layer/input_layer.py @@ -393,9 +393,8 @@ def backward(self, data: torch.Tensor, node_flows: torch.Tensor, ) # Handle the masked input nodes - if missing_mask is not None: + if missing_mask is not None and self.bk_flow_mask_fn is not None: if not self.provided("_flows_mask_kernel"): - assert self.bk_flow_mask_fn is not None, f"`bk_flow_mask_fn` is not implemented for distribution {type(self.dist)}." self._flows_mask_kernel = self._compile_triton_kernel(self._flows_kernel_template, flow_fn = self.bk_flow_mask_fn) self._flows_mask_kernel[grid](