Skip to content

Commit

Permalink
removed unnecessary vars, new checjpoint save, and new standard refer…
Browse files Browse the repository at this point in the history
  • Loading branch information
camilo-nunez committed Jul 23, 2023
1 parent 0b26c96 commit aa56085
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions train_A.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import albumentations as A
from albumentations.pytorch import ToTensorV2

from model.builder import BackboneNeck, AVAILABLE_NECKS, AVAILABLE_BACKBONES
from config.init import create_train_config_A
from model.builder import BackboneNeck
from config.init import create_train_config
from utils.datasets import VOCDetectionV2, CocoDetectionV2

AVAILABLE_DATASETS = ['coco2017', 'voc2012']
Expand Down Expand Up @@ -79,7 +79,7 @@ def parse_option():
help="Display the summary of the model.")

args, unparsed = parser.parse_known_args()
config = create_train_config_A(args)
config = create_train_config(args)

return args, config

Expand Down Expand Up @@ -117,7 +117,7 @@ def parse_option():
_num_classes = len(base_config.DATASET.OBJ_LIST)
print(f'[++] Numbers of classes: {_num_classes}')
base_model = torchvision.models.detection.FasterRCNN(backbone_neck,
num_classes=_num_classes,
num_classes=_num_classes + 1, # +1 = background
rpn_anchor_generator=anchor_generator,
box_roi_pool=roi_pooler).to(device)
print('[+] Ready !')
Expand Down Expand Up @@ -242,19 +242,18 @@ def parse_option():

global_steps+=1

# if loss_median < best_loss:
# best_loss = loss_median
if loss_median < best_loss:
best_loss = loss_median

# torch.save({'model_state_dict': base_model.state_dict(),
# 'neck_state_dict': base_model.backbone.fpn_backbone.state_dict(),
# 'optimizer_state_dict': optimizer.state_dict(),
# 'epoch': epoch,
# 'best_loss': best_loss,
# 'fn_cfg_dataset': str(args.cfg_dataset),
# 'fn_cfg_model': str(args.cfg_model),
# 'fpn_type': base_config.MODEL.BIFPN.TYPE,
# },
# os.path.join(args.checkpoint_path, f'{datetime.utcnow().strftime("%Y%m%d_%H%M")}_A1_{base_config.MODEL.BIFPN.TYPE}_{base_config.MODEL.BACKBONE.NAME}_{base_config.MODEL.BIFPN.NAME}_{epoch}.pth'))
torch.save({'model_state_dict': base_model.state_dict(),
'neck_state_dict': base_model.backbone.neck.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
'best_loss': best_loss,
'fn_cfg_dataset': str(args.cfg_dataset),
'fn_cfg_model': str(args.cfg_model),
},
os.path.join(args.checkpoint_path, f'{datetime.utcnow().strftime("%Y%m%d_%H%M")}_A_{base_config.MODEL.BACKBONE.MODEL_NAME}_{base_config.MODEL.NECK.MODEL_NAME}_{epoch}.pth'))

end_t = datetime.now()
print('[+] Ready, the train phase took:', (end_t - start_t))
Expand Down

0 comments on commit aa56085

Please sign in to comment.