Skip to content

Commit

Permalink
extract_bboxes_pt21.py add separate config parameter
Browse files Browse the repository at this point in the history
so values don't get overridden as much

Signed-off-by: Martin <[email protected]>
  • Loading branch information
bmmtstb committed Jun 28, 2024
1 parent 09bb620 commit d320d04
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions scripts/helpers/extract_bboxes_pt21.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def save_crops(_s: State, img_dir: FilePath, _gt_img_id: str | int) -> None:
torch.save(_s.keypoints_local[i].unsqueeze(0).cpu(), str(img_path).replace(".jpg", ".pt"))


def predict_and_save_rcnn(dl_key: str, subm_key: str, rcnn_cfg_str: str) -> None:
def predict_and_save_rcnn(config: Config, dl_key: str, subm_key: str, rcnn_cfg_str: str) -> None:
"""Predict and save the rcnn results of all the PT21 datasets in the folder given by the config."""
# pylint: disable=too-many-locals

Expand Down Expand Up @@ -158,7 +158,7 @@ def predict_and_save_rcnn(dl_key: str, subm_key: str, rcnn_cfg_str: str) -> None
subm_module.save()


def extract_gt_boxes(dl_key: str) -> None:
def extract_gt_boxes(config: Config, dl_key: str) -> None:
"""Given the gt annotations, extract the image crops and local coordinates."""
dataset_paths: list[str] = glob(config[dl_key]["dataset_paths"])

Expand Down Expand Up @@ -193,28 +193,29 @@ def extract_gt_boxes(dl_key: str) -> None:
if __name__ == "__main__":
print(f"Cuda available: {torch.cuda.is_available()}")

config: Config = load_config(CONFIG_FILE)

for DL_KEY in DL_KEYS:
print(f"Extracting ground-truth: {DL_KEY}")
extract_gt_boxes(dl_key=DL_KEY)
cfg: Config = load_config(CONFIG_FILE)
extract_gt_boxes(config=cfg, dl_key=DL_KEY)

for RCNN_DL_KEY in RCNN_DL_KEYS:
print(f"Extracting rcnn: {RCNN_DL_KEY}")

rcnn_cfg: Config = load_config(CONFIG_FILE)
for score_threshold in (pbar_score_thresh := tqdm(SCORE_THRESHS, desc="Score-Threshold")):
pbar_score_thresh.set_postfix_str(str(score_threshold))

score_str = f"{int(score_threshold * 100):03d}"
config[RCNN_DL_KEY]["score_threshold"] = score_threshold
rcnn_cfg[RCNN_DL_KEY]["score_threshold"] = score_threshold

for iou_threshold in (pbar_iou_thresh := tqdm(IOU_THRESHS, desc="IoU-Threshold")):
pbar_iou_thresh.set_postfix_str(str(iou_threshold))

iou_str = f"{int(iou_threshold * 100):03d}"
config[RCNN_DL_KEY]["iou_threshold"] = iou_threshold
rcnn_cfg[RCNN_DL_KEY]["iou_threshold"] = iou_threshold

predict_and_save_rcnn(
config=rcnn_cfg,
dl_key=RCNN_DL_KEY,
subm_key=SUBM_KEY,
rcnn_cfg_str=f"rcnn_{score_str}_{iou_str}_{RCNN_DL_KEY.rsplit('_', maxsplit=1)[-1]}",
Expand Down

0 comments on commit d320d04

Please sign in to comment.