Finalized the function that extracts all bounding boxes of the PT21 dataset
posetrack.py: extract_all_bboxes()
The PoseTrack21Torchreid model does not use this yet (WIP).

Signed-off-by: Martin <[email protected]>
bmmtstb committed Jan 6, 2024
1 parent b05ae65 commit 3335ae4
Showing 1 changed file with 32 additions and 28 deletions.
60 changes: 32 additions & 28 deletions dgs/models/dataset/posetrack.py
@@ -55,7 +55,7 @@ def validate_pt21_json(json: dict) -> None:

def extract_all_bboxes(
base_dataset_path: FilePath = "./data/PoseTrack21/",
data_dir: FilePath = "./posetrack_data_fast/",
anno_dir: FilePath = "./posetrack_data/",
crop_size: ImgShape = (256, 256),
transform_mode: str = "zero-pad",
**kwargs,
@@ -67,7 +67,7 @@ def extract_all_bboxes(
Args:
base_dataset_path (FilePath): The path to the |PT21| dataset directory.
data_dir (FilePath): The name of the directory containing the folders for the training and test annotations.
anno_dir (FilePath): The name of the directory containing the folders for the training and test annotations.
crop_size (ImgShape): The target shape of the image crops.
transform_mode (str): Defines the resize mode, has to be in the modes of
:class:`~dgs.utils.image.CustomToAspect`. Default "zero-pad".
@@ -77,15 +77,15 @@ def extract_all_bboxes(
quality (int): The quality to save the jpegs as. Default 90. Default of torchvision is 75.
check_img_sizes (bool): Whether to check if all images in a given folder have the same size before stacking them
for cropping. Default False.
load_image (dict[str, any]): additional kwargs passed to load_image() function. Default {}.
"""
# pylint: disable=too-many-locals

base_dataset_path = to_abspath(base_dataset_path)
annos_path = to_abspath(os.path.join(base_dataset_path, data_dir))
crops_path = os.path.join(base_dataset_path, "crops")

# extract kwargs
device: Device = kwargs.pop("device", "cuda" if torch.cuda.is_available() else "cpu")
quality: int = kwargs.pop("quality", 90)
check_img_sizes: bool = kwargs.pop("check_img_sizes", False)

mkdir_if_missing(crops_path)

@@ -97,7 +97,7 @@ def extract_all_bboxes(
]
)

for abs_anno_path, _, files in tqdm(os.walk(annos_path), desc="datasets", position=0):
for abs_anno_path, _, files in tqdm(os.walk(os.path.join(base_dataset_path, anno_dir)), desc="annos", position=0):
# skip directories that don't contain files, e.g., the folder containing the datasets
if len(files) == 0:
continue
@@ -106,31 +106,30 @@ def extract_all_bboxes(
# abs_anno_path => .../PoseTrack21/images/{train}/{dataset_name}.json

# create folder {train} inside crops
train_folder_name = abs_anno_path.split("/")[-1]
crops_train_dir = os.path.join(crops_path, train_folder_name)
crops_train_dir = os.path.join(crops_path, abs_anno_path.split("/")[-1])
mkdir_if_missing(crops_train_dir)

for anno_file in tqdm(files, desc="annotation-files", position=1):
if not anno_file.endswith(".json"):
continue
try:
json = read_json(os.path.join(abs_anno_path, anno_file))
validate_pt21_json(json)
except Exception as e:
warnings.warn(str(e))
continue

# load and validate json
json = read_json(os.path.join(abs_anno_path, anno_file))
validate_pt21_json(json)

# get the folder name which is the name of the sub-dataset and create the folder within crops
dataset_name = anno_file.split(".")[0]
crops_subset_path = os.path.join(crops_train_dir, dataset_name)
crops_subset_path = os.path.join(crops_train_dir, anno_file.split(".")[0])
mkdir_if_missing(crops_subset_path)

# skip if the folder has the correct number of files
if len(os.listdir(crops_subset_path)) == len(json["annotations"]):
continue

# check that the image sizes in every folder match
if check_img_sizes and len(set(imagesize.get(os.path.join(crops_subset_path, f)) for f in files)) != 1:
if (
kwargs.get("check_img_sizes", False)
and len(set(imagesize.get(os.path.join(crops_subset_path, f)) for f in files)) != 1
):
warnings.warn(f"In folder {crops_subset_path} the images do not have the same size.")

# Because the images in every folder have the same shape,
@@ -150,25 +149,30 @@ def extract_all_bboxes(
new_img_name = f"{anno['image_id']}_{str(anno['person_id'])}.jpg"
new_fps.append(os.path.join(crops_subset_path, new_img_name))

del json

imgs: TVImage = load_image(
filepath=tuple(img_fps),
device=device,
requires_grad=False,
**kwargs.get("load_image", {}),
)

data = {
"image": imgs,
"box": tv_tensors.BoundingBoxes(
torch.stack(boxes), format="XYWH", canvas_size=imgs.shape[-2:], device=device
),
"keypoints": torch.zeros((imgs.shape[-4], 1, 2), device=device),
"mode": transform_mode,
"output_size": crop_size,
}
crops = transform(data)["image"].cpu()
# pass original images through CustomResizeCrop transform and get the resulting image crops
crops = transform(
{
"image": imgs,
"box": tv_tensors.BoundingBoxes(
torch.stack(boxes), format="XYWH", canvas_size=imgs.shape[-2:], device=device
),
"keypoints": torch.zeros((imgs.shape[-4], 1, 2), device=device),
"mode": transform_mode,
"output_size": crop_size,
}
)["image"]

for fp, crop in zip(new_fps, crops):
write_jpeg(input=crop, filename=fp, quality=quality)
write_jpeg(input=crop, filename=fp, quality=kwargs.get("quality", 90))


def get_pose_track_21(config: Config, path: NodePath) -> TorchDataset:

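For reference, a minimal usage sketch of the updated function, assuming the module path shown in the diff above and the documented keyword arguments; the specific values passed here are illustrative defaults taken from the signature and docstring, not part of the commit:

from dgs.models.dataset.posetrack import extract_all_bboxes

# Crop every annotated bounding box of PoseTrack21 into <base_dataset_path>/crops/,
# zero-padding each crop to 256x256 pixels before writing it as a jpeg.
extract_all_bboxes(
    base_dataset_path="./data/PoseTrack21/",
    anno_dir="./posetrack_data/",
    crop_size=(256, 256),
    transform_mode="zero-pad",
    device="cuda",          # optional kwarg; defaults to cuda if available, else cpu
    quality=90,             # jpeg quality, default 90
    check_img_sizes=False,  # per-folder image-size sanity check, default False
)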