diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index 3eeb0dda80..fe29dcf9c0 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -349,7 +349,16 @@ def get_palette_for_custom_classes(self, class_names, palette=None): elif palette is None: if self.PALETTE is None: + # Get random state before set seed, and restore + # random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. + # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette palette = np.random.randint(0, 255, size=(len(class_names), 3)) + np.random.set_state(state) else: palette = self.PALETTE diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index f0f320ffbf..9b22a7ca9b 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -245,8 +245,17 @@ def show_result(self, seg = result[0] if palette is None: if self.PALETTE is None: + # Get random state before set seed, + # and restore random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. + # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette palette = np.random.randint( 0, 255, size=(len(self.CLASSES), 3)) + np.random.set_state(state) else: palette = self.PALETTE palette = np.array(palette)