diff --git a/test/test_distributed.py b/test/test_distributed.py index d04b382c54268c..d6ce0cb3ae8040 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -1362,9 +1362,12 @@ def _join_and_reduce(self, fn): getattr(fn, "skip_if_no_gpu", False) or getattr(fn, "skip_if_small_worldsize", False) ) - self.JOIN_TIMEOUT = get_timeout(self.id()) - for p in self.processes: - p.join(self.JOIN_TIMEOUT) + join_timeout = get_timeout(self.id()) + for rank, process in enumerate(self.processes): + process.join(join_timeout) + self.assertFalse( + process.is_alive(), + "Timeout waiting for rank %d to terminate" % rank) first_process = self.processes[0] for p in self.processes: diff --git a/test/test_thd_distributed.py b/test/test_thd_distributed.py index 25072f121c6e13..32438a1602f5cd 100644 --- a/test/test_thd_distributed.py +++ b/test/test_thd_distributed.py @@ -1110,9 +1110,12 @@ def _join_and_reduce(self, fn): getattr(fn, "skip_if_no_gpu", False) or getattr(fn, "skip_if_small_worldsize", False) ) - self.JOIN_TIMEOUT = get_timeout(self.id()) - for p in self.processes: - p.join(self.JOIN_TIMEOUT) + join_timeout = get_timeout(self.id()) + for rank, process in enumerate(self.processes): + process.join(join_timeout) + self.assertFalse( + process.is_alive(), + "Timeout waiting for rank %d to terminate" % rank) first_process = self.processes[0] for p in self.processes: