diff --git a/dgs/models/dataset/keypoint_rcnn.py b/dgs/models/dataset/keypoint_rcnn.py index e112804..5ac0d38 100644 --- a/dgs/models/dataset/keypoint_rcnn.py +++ b/dgs/models/dataset/keypoint_rcnn.py @@ -7,6 +7,7 @@ import os from abc import ABC +from typing import Union import torch as t from imagesize import imagesize @@ -243,7 +244,7 @@ class KeypointRCNNImageBackbone(KeypointRCNNBackbone, ImageDataset): __doc__ += KeypointRCNNBackbone.__doc__ data: list[FilePath] - masks: list[tvte.Mask | None] + masks: list[Union[tvte.Mask, None]] def __init__(self, config: Config, path: NodePath) -> None: KeypointRCNNBackbone.__init__(self, config=config, path=path)