From f4efdd56ce108226e9fc584e7e046e32d58ed1e3 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Tue, 13 Aug 2024 11:45:19 -0400 Subject: [PATCH] address issue with clamp_min on MPS --- tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py b/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py index eb64518d..cb03f5fa 100644 --- a/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py +++ b/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py @@ -982,7 +982,7 @@ def forward( `(batch_size, 1, num_input_channels)`) """ denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim) - denominator = denominator.clamp_min(1.0) + denominator = denominator.clamp_min(torch.tensor(1, device=denominator.device)) loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator