Commit: clean up
bmmtstb committed Sep 12, 2024
1 parent 20b6762 commit 9f36d1e
Showing 3 changed files with 60 additions and 55 deletions.
2 changes: 1 addition & 1 deletion dgs/models/engine/engine.py
@@ -295,7 +295,7 @@ def train_model(self) -> optim.Optimizer:
         Returns:
             The current optimizer after training.
         """
-        # pylint: disable=too-many-statements
+        # pylint: disable=too-many-statements,too-many-branches

         if self.train_dl is None:
             raise ValueError("No DataLoader for the Training data was given. Can't continue.")
47 changes: 27 additions & 20 deletions dgs/utils/state.py
@@ -453,7 +453,7 @@ def image_crop(self) -> Image:
         If the crops are not available, try to load them using :func:`load_image_crop` and :attr:`crop_path`.
         """
         if "image_crop" not in self.data:
-            return self.load_image_crop()
+            return self.load_image_crop(store=False)
         return self.data["image_crop"]

     @image_crop.setter
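With this change, reading `image_crop` is side-effect free: when the crop is missing, the property falls back to `load_image_crop(store=False)` and returns the loaded crop without caching it, and callers that want caching opt in explicitly. A minimal, self-contained sketch of that pattern (the class and values below are illustrative stand-ins, not the real dgs API):

    class LazyCrop:
        """Sketch: lazy loading without implicit caching."""

        def __init__(self) -> None:
            self.data: dict = {}

        def load_image_crop(self, store: bool = False) -> str:
            crop = "expensive-to-load crop"  # stand-in for the real image load
            if store:
                self.data["image_crop"] = crop  # cache only when asked to
            return crop

        @property
        def image_crop(self) -> str:
            if "image_crop" not in self.data:
                return self.load_image_crop(store=False)  # a read does not cache
            return self.data["image_crop"]

    s = LazyCrop()
    _ = s.image_crop
    assert "image_crop" not in s.data  # property access left the state untouched
    s.load_image_crop(store=True)
    assert "image_crop" in s.data  # explicit store=True caches the crop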
@@ -576,6 +576,14 @@ def load_image_crop(self, store: bool = False, **kwargs) -> Image:
             crop_size: The size of the image crops.
                 Default ``DEF_VAL.images.crop_size``.
         """
+
+        def save_values(crop_: Image, loc_kps_: Union[t.Tensor, None]) -> None:
+            if store:
+                self.data["image_crop"] = crop_
+                if loc_kps_ is not None:
+                    self.data["keypoints_local"] = loc_kps_
+
+        # data already exists
         if (
             "image_crop" in self
             and self.data["image_crop"] is not None
@@ -584,38 +592,37 @@ def load_image_crop(self, store: bool = False, **kwargs) -> Image:
         ):
             return self.image_crop

+        # empty state
         if self.B == 0:
             crop = t.empty((0, 3, 0, 0), device=self.device, dtype=t.long)
-            if store:
-                self.data["image_crop"] = crop
+            save_values(crop_=crop, loc_kps_=None)
             return crop

+        # load from crop path
         if "crop_path" in self:
-            if len(self.crop_path) == 0:
-                crop = []
-                loc_kps = t.empty((0, 1, 2), dtype=t.long, device=self.device)
-            else:
-                # allow changing the crop_size and other params via kwargs
-                crop = load_image(filepath=self.crop_path, device=self.device, **kwargs)
+            assert (
+                len(self.crop_path) > 0
+            ), f"expected to have at least one entry in crop_path, got: {len(self.crop_path)}"

-                kps_paths = tuple(replace_file_type(sub_path, new_type=".pt") for sub_path in self.crop_path)
-                if all(is_file(path) for path in kps_paths):
-                    loc_kps = self.keypoints_and_weights_from_paths(kps_paths, save_weights=store)
-                else:
-                    loc_kps = None
-            if store:
-                self.data["image_crop"] = crop
-                if loc_kps is not None:
-                    self.data["keypoints_local"] = loc_kps
+            # allow changing the crop_size and other params via kwargs
+            crop = load_image(filepath=self.crop_path, device=self.device, **kwargs)
+            kps_paths = tuple(replace_file_type(sub_path, new_type=".pt") for sub_path in self.crop_path)
+
+            if all(is_file(path) for path in kps_paths):
+                loc_kps = self.keypoints_and_weights_from_paths(kps_paths, save_weights=store)
+            else:
+                loc_kps = None
+            save_values(crop_=crop, loc_kps_=loc_kps)
             return crop

+        # try to extract using image and bbox
         try:
             kps = self.keypoints if "keypoints" in self.data else None
             crop, loc_kps = extract_crops_from_images(imgs=self.image, bboxes=self.bbox, kps=kps, **kwargs)
             if store:
-                self.image_crop = crop
+                self.data["image_crop"] = crop
                 if kps is not None:
-                    self.keypoints_local = loc_kps
+                    self.data["keypoints_local"] = loc_kps
             return crop
         except AttributeError as e:
             raise AttributeError(
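The new `save_values` closure removes the `if store:` blocks that were previously duplicated in every branch of `load_image_crop`; each early `return` now persists its result through the same code path. A compact, runnable sketch of this closure-over-a-flag refactor (simplified stand-ins, not the real `State` class):

    from typing import Optional

    def load_crop(data: dict, batch_size: int, store: bool = False) -> str:
        """Sketch: one closure handles storing for every fallback branch."""

        def save_values(crop_: str, loc_kps_: Optional[str]) -> None:
            # the single place that decides whether results are cached
            if store:
                data["image_crop"] = crop_
                if loc_kps_ is not None:
                    data["keypoints_local"] = loc_kps_

        if "image_crop" in data:  # data already exists
            return data["image_crop"]
        if batch_size == 0:  # empty state
            crop = "empty crop"
            save_values(crop_=crop, loc_kps_=None)
            return crop
        crop, loc_kps = "loaded crop", "loaded keypoints"  # load from crop path
        save_values(crop_=crop, loc_kps_=loc_kps)
        return crop

    cache: dict = {}
    load_crop(cache, batch_size=2, store=True)
    assert cache["image_crop"] == "loaded crop"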
66 changes: 32 additions & 34 deletions scripts/own/train_dynamic_weights.py
@@ -64,11 +64,11 @@
 # @MemoryTracker(interval=7.5, top_n=20)
 @notify_on_completion_or_error(min_time=30, info="run initial weight")
 @t.no_grad()
-def test_pt21(config: Config, dl_key: str, paths: list, out_key: str, dgs_key: str) -> None:
+def test_pt21(cfg: Config, dl_key: str, paths: list, out_key: str, dgs_key: str) -> None:
     """Set the PT21 config."""
-    crop_h, crop_w = config[dl_key]["crop_size"]
-    config[dl_key]["crops_folder"] = (
-        config[dl_key]["base_path"]
+    crop_h, crop_w = cfg[dl_key]["crop_size"]
+    cfg[dl_key]["crops_folder"] = (
+        cfg[dl_key]["base_path"]
         .replace("posetrack_data", f"crops/{crop_h}x{crop_w}")
         .replace(f"{crop_h}x{crop_w}_", "")  # remove redundant from crop folder name iff existing
     )
@@ -77,79 +77,77 @@ def test_pt21(config: Config, dl_key: str, paths: list, out_key: str, dgs_key: str) -> None:
     for sub_datapath in (pbar_data := tqdm(paths, desc="ds_sub_dir", leave=False)):
         pbar_data.set_postfix_str(os.path.basename(sub_datapath))
         # make sure to have a unique log dir every time
-        orig_log_dir = config["log_dir"]
+        orig_log_dir = cfg["log_dir"]

         # change config data
-        config[dl_key]["data_path"] = sub_datapath
-        config["log_dir"] += f"./{out_key}/{dgs_key}/"
-        config["test"]["submission"] = ["submission_pt21"]
+        cfg[dl_key]["data_path"] = sub_datapath
+        cfg["log_dir"] += f"./{out_key}/{dgs_key}/"
+        cfg["test"]["submission"] = ["submission_pt21"]

         # set the new path for the out file in the log_dir
         subm_key = "submission_pt21"
-        config[subm_key]["file"] = os.path.abspath(
-            os.path.normpath(
-                f"{config['log_dir']}/results_json/{sub_datapath.split('/')[-1].removesuffix('.json')}.json"
-            )
+        cfg[subm_key]["file"] = os.path.abspath(
+            os.path.normpath(f"{cfg['log_dir']}/results_json/{sub_datapath.split('/')[-1].removesuffix('.json')}.json")
         )

-        if os.path.exists(config[subm_key]["file"]):
+        if os.path.exists(cfg[subm_key]["file"]):
             # reset the original log dir
-            config["log_dir"] = orig_log_dir
+            cfg["log_dir"] = orig_log_dir
             continue

-        engine = get_dgs_engine(config=config, dl_keys=(None, None, dl_key), dgs_key=dgs_key)
+        engine = get_dgs_engine(cfg=cfg, dl_keys=(None, None, dl_key), dgs_key=dgs_key)
         engine.test()
         # end processes
         engine.terminate()

         # reset the original log dir
-        config["log_dir"] = orig_log_dir
+        cfg["log_dir"] = orig_log_dir


 # @torch_memory_analysis
 @notify_on_completion_or_error(min_time=30, info="run initial weight")
 @t.no_grad()
-def test_dance(config: Config, dl_key: str, paths: list, out_key: str, dgs_key: str) -> None:
+def test_dance(cfg: Config, dl_key: str, paths: list, out_key: str, dgs_key: str) -> None:
     """Set the DanceTrack config."""

     # get all the sub folders or files and analyze them one-by-one
     for sub_datapath in (pbar_data := tqdm(paths, desc="ds_sub_dir", leave=False)):
         dataset_path = os.path.normpath(os.path.dirname(os.path.dirname(sub_datapath)))
         dataset_name = os.path.basename(dataset_path)
         pbar_data.set_postfix_str(dataset_name)
-        config[dl_key]["data_path"] = sub_datapath
+        cfg[dl_key]["data_path"] = sub_datapath

         # make sure to have a unique log dir every time
-        orig_log_dir = config["log_dir"]
+        orig_log_dir = cfg["log_dir"]

         # change config data
-        config["log_dir"] += f"./{out_key}/{dgs_key}/"
-        config["test"]["writer_log_dir_suffix"] = f"./{os.path.basename(sub_datapath)}/"
+        cfg["log_dir"] += f"./{out_key}/{dgs_key}/"
+        cfg["test"]["writer_log_dir_suffix"] = f"./{os.path.basename(sub_datapath)}/"

         # set the new path for the submission file
         subm_key = "submission_MOT"
-        config["test"]["submission"] = [subm_key]
-        config[subm_key]["file"] = os.path.abspath(
+        cfg["test"]["submission"] = [subm_key]
+        cfg[subm_key]["file"] = os.path.abspath(
             os.path.normpath(f"{os.path.dirname(dataset_path)}./results_{out_key}_{dgs_key}/{dataset_name}.txt")
         )

-        if os.path.exists(config[subm_key]["file"]):
+        if os.path.exists(cfg[subm_key]["file"]):
             # reset the original log dir
-            config["log_dir"] = orig_log_dir
+            cfg["log_dir"] = orig_log_dir
             continue

-        engine = get_dgs_engine(config=config, dl_keys=(None, None, dl_key), dgs_key=dgs_key)
+        engine = get_dgs_engine(cfg=cfg, dl_keys=(None, None, dl_key), dgs_key=dgs_key)
         engine.test()
         # end processes
         engine.terminate()

         # reset the original log dir
-        config["log_dir"] = orig_log_dir
+        cfg["log_dir"] = orig_log_dir


 @t.no_grad()
 def get_dgs_engine(
-    config: Config,
+    cfg: Config,
     dl_keys: tuple[str | None, str | None, str | None],
     engine_key: str = "engine",
     dgs_key: str = "DGSModule",
@@ -161,17 +159,17 @@ def get_dgs_engine(

     # the DGSModule will load all the similarity modules internally
     kwargs = {
-        "model": module_loader(config=config, module_class="dgs", key=dgs_key),
+        "model": module_loader(config=cfg, module_class="dgs", key=dgs_key),
     }
     # validation dataset
     if key_train is not None:
-        kwargs["train_loader"] = module_loader(config=config, module_class="dataloader", key=key_train)
+        kwargs["train_loader"] = module_loader(config=cfg, module_class="dataloader", key=key_train)
     if key_eval is not None:
-        kwargs["val_loader"] = module_loader(config=config, module_class="dataloader", key=key_eval)
+        kwargs["val_loader"] = module_loader(config=cfg, module_class="dataloader", key=key_eval)
     if key_test is not None:
-        kwargs["test_loader"] = module_loader(config=config, module_class="dataloader", key=key_test)
+        kwargs["test_loader"] = module_loader(config=cfg, module_class="dataloader", key=key_test)

-    return module_loader(config=config, module_class="engine", key=engine_key, **kwargs)
+    return module_loader(config=cfg, module_class="engine", key=engine_key, **kwargs)


 def train_dgs_engine(cfg: Config, dl_train_key: str, dl_eval_key: str, alpha_mod_name: str, sim_name: str) -> None:
@@ -200,7 +198,7 @@ def train_dgs_engine(cfg: Config, dl_train_key: str, dl_eval_key: str, alpha_mod_name: str, sim_name: str) -> None:
print(f"Training on the ground-truth train-dataset with config: {dl_train_key} - {alpha_mod_name}")

# use the modified config and obtain the model used for training
engine_train = get_dgs_engine(config=cfg, dl_keys=(dl_train_key, dl_eval_key, None))
engine_train = get_dgs_engine(cfg=cfg, dl_keys=(dl_train_key, dl_eval_key, None))

# set model and initialize the weights
engine_train.model.combine.alpha_model = nn.ModuleList([ALPHA_MODULES[alpha_mod_name]])
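Renaming the parameter from `config` to `cfg` throughout the script matches `train_dgs_engine` and avoids confusing the variable with the `config=` keyword that `module_loader` itself takes; `get_dgs_engine` continues to build only the data loaders whose keys are provided. A short sketch of that conditional wiring (the `module_loader` below is a hypothetical stand-in, not the real dgs loader):

    from typing import Optional

    def module_loader(config: dict, module_class: str, key: str, **kwargs) -> str:
        # stand-in: the real loader builds a module from the config by key
        return f"{module_class}/{key}"

    def get_engine(cfg: dict, dl_keys: tuple[Optional[str], Optional[str], Optional[str]]) -> str:
        key_train, key_eval, key_test = dl_keys
        kwargs = {"model": module_loader(config=cfg, module_class="dgs", key="DGSModule")}
        # only the loaders whose keys were given are constructed and passed on
        if key_train is not None:
            kwargs["train_loader"] = module_loader(config=cfg, module_class="dataloader", key=key_train)
        if key_eval is not None:
            kwargs["val_loader"] = module_loader(config=cfg, module_class="dataloader", key=key_eval)
        if key_test is not None:
            kwargs["test_loader"] = module_loader(config=cfg, module_class="dataloader", key=key_test)
        return module_loader(config=cfg, module_class="engine", key="engine", **kwargs)

    # test-only run: no train or validation loader is built
    engine = get_engine({}, dl_keys=(None, None, "dl_test"))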
