From 415a7978bb07bf97eb9e94660a94a3e884656405 Mon Sep 17 00:00:00 2001 From: "Eom, Jihwan" Date: Fri, 11 Nov 2022 00:40:01 +0900 Subject: [PATCH] Move drop_last in cls trainer.py --- mpa/cls/trainer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mpa/cls/trainer.py b/mpa/cls/trainer.py index 2ca576a2..c17d26fe 100644 --- a/mpa/cls/trainer.py +++ b/mpa/cls/trainer.py @@ -151,8 +151,12 @@ def train_worker(gpu, dataset, cfg, distributed, validate, timestamp, meta): # prepare data loaders dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] train_data_cfg = Stage.get_data_cfg(cfg, "train") - drop_last = train_data_cfg.drop_last if train_data_cfg.get('drop_last', False) else False - + ote_dataset = train_data_cfg.get('ote_dataset', None) + drop_last = False + dataset_len = len(ote_dataset) if ote_dataset else 0 + # if task == h-label & dataset size is bigger than batch size + if train_data_cfg.get('hierarchical_info', None) and dataset_len > cfg.data.get('samples_per_gpu', 2): + drop_last = True # updated to adapt list of dataset for the 'train' data_loaders = [] sub_loaders = []