diff --git a/train.py b/train.py
index ae2298fc..cce87dd1 100644
--- a/train.py
+++ b/train.py
@@ -29,12 +29,19 @@
                   'eager_fit: model.fit(run_eagerly=True), '
                   'eager_tf: custom GradientTape')
 flags.DEFINE_enum('transfer', 'none',
-                  ['none', 'darknet', 'no_output', 'frozen', 'fine_tune'],
-                  'none: Training from scratch, '
-                  'darknet: Transfer darknet, '
-                  'no_output: Transfer all but output, '
-                  'frozen: Transfer and freeze all, '
-                  'fine_tune: Transfer all and freeze darknet only')
+                  ['none', 'yolo_darknet', 'yolo_conv', 'yolo_output_conv', 'all'],
+                  'none: Training from scratch (no weights transfer), '
+                  'yolo_darknet: Transfer darknet sub-model weights, '
+                  'yolo_conv: Transfer darknet and conv sub-model weights, '
+                  'yolo_output_conv: Transfer darknet and conv sub-model weights and first output conv layer weights, '
+                  'all: Transfer all weights (pretrained weights need to have the same number of classes)')
+flags.DEFINE_enum('freeze', 'none',
+                  ['none', 'yolo_darknet', 'yolo_conv', 'yolo_output_conv', 'all'],
+                  'none: Tune all weights, '
+                  'yolo_darknet: Tune all but darknet sub-model weights, '
+                  'yolo_conv: Tune output sub-model weights, '
+                  'yolo_output_conv: Tune only output sub-model without the first conv layer, '
+                  'all: Do not allow tuning of weights')
 flags.DEFINE_integer('size', 416, 'image size')
 flags.DEFINE_integer('epochs', 2, 'number of epochs')
 flags.DEFINE_integer('batch_size', 8, 'batch size')
@@ -81,44 +88,70 @@ def main(_argv):
         dataset.transform_targets(y, anchors, anchor_masks, FLAGS.size)))
 
     # Configure the model for transfer learning
-    if FLAGS.transfer == 'none':
-        pass  # Nothing to do
-    elif FLAGS.transfer in ['darknet', 'no_output']:
-        # Darknet transfer is a special case that works
-        # with incompatible number of classes
-
-        # reset top layers
-        if FLAGS.tiny:
-            model_pretrained = YoloV3Tiny(
-                FLAGS.size, training=True, classes=FLAGS.weights_num_classes or FLAGS.num_classes)
-        else:
-            model_pretrained = YoloV3(
-                FLAGS.size, training=True, classes=FLAGS.weights_num_classes or FLAGS.num_classes)
-        model_pretrained.load_weights(FLAGS.weights)
-
+    if FLAGS.transfer != 'none':
+        # if we need all weights, no need to create another model
+        if FLAGS.transfer == 'all':
+            model.load_weights(FLAGS.weights)
-        if FLAGS.transfer == 'darknet':
-            model.get_layer('yolo_darknet').set_weights(
-                model_pretrained.get_layer('yolo_darknet').get_weights())
-            freeze_all(model.get_layer('yolo_darknet'))
-
+        # else, we need only some of the weights
+        # create appropriate model_pretrained, load all weights and copy the ones we need
+        else:
+            if FLAGS.tiny:
+                model_pretrained = YoloV3Tiny(FLAGS.size, training=True, classes=FLAGS.weights_num_classes or FLAGS.num_classes)
+            else:
+                model_pretrained = YoloV3(FLAGS.size, training=True, classes=FLAGS.weights_num_classes or FLAGS.num_classes)
+            # load pretrained weights
+            model_pretrained.load_weights(FLAGS.weights)
+            # transfer darknet
+            darknet = model.get_layer('yolo_darknet')
+            darknet.set_weights(model_pretrained.get_layer('yolo_darknet').get_weights())
+            # transfer 'yolo_conv_i' layer weights
+            if FLAGS.transfer in ['yolo_conv', 'yolo_output_conv']:
+                for l in model.layers:
+                    if l.name.startswith('yolo_conv'):
+                        model.get_layer(l.name).set_weights(model_pretrained.get_layer(l.name).get_weights())
+            # transfer 'yolo_output_i' first conv2d layer
+            if FLAGS.transfer == 'yolo_output_conv':
+                # transfer tiny output conv2d
+                if FLAGS.tiny:
+                    # get and set the weights of the appropriate layers
+                    model.layers[4].layers[1].set_weights(model_pretrained.layers[4].layers[1].get_weights())
+                    model.layers[5].layers[1].set_weights(model_pretrained.layers[5].layers[1].get_weights())
+                    # should I freeze batch_norm as well?
+                else:
+                    # get and set the weights of the appropriate layers
+                    model.layers[5].layers[1].set_weights(model_pretrained.layers[5].layers[1].get_weights())
+                    model.layers[6].layers[1].set_weights(model_pretrained.layers[6].layers[1].get_weights())
+                    model.layers[7].layers[1].set_weights(model_pretrained.layers[7].layers[1].get_weights())
+                    # should I freeze batch_norm as well?
+    # no transfer learning
+    else:
+        pass
+
-        elif FLAGS.transfer == 'no_output':
+    # freeze layers, if requested
+    if FLAGS.freeze != 'none':
+        if FLAGS.freeze == 'all':
+            freeze_all(model)
+        if FLAGS.freeze in ['yolo_darknet', 'yolo_conv', 'yolo_output_conv']:
+            darknet = model.get_layer('yolo_darknet')
+            freeze_all(darknet)
+        if FLAGS.freeze in ['yolo_conv', 'yolo_output_conv']:
             for l in model.layers:
-                if not l.name.startswith('yolo_output'):
-                    l.set_weights(model_pretrained.get_layer(
-                        l.name).get_weights())
+                if l.name.startswith('yolo_conv'):
                     freeze_all(l)
-
+        if FLAGS.freeze == 'yolo_output_conv':
+            if FLAGS.tiny:
+                # freeze the appropriate layers
+                freeze_all(model.layers[4].layers[1])
+                freeze_all(model.layers[5].layers[1])
+            else:
+                # freeze the appropriate layers
+                freeze_all(model.layers[5].layers[1])
+                freeze_all(model.layers[6].layers[1])
+                freeze_all(model.layers[7].layers[1])
+    # freeze nothing
     else:
-        # All other transfer require matching classes
-        model.load_weights(FLAGS.weights)
-        if FLAGS.transfer == 'fine_tune':
-            # freeze darknet and fine tune other layers
-            darknet = model.get_layer('yolo_darknet')
-            freeze_all(darknet)
-        elif FLAGS.transfer == 'frozen':
-            # freeze everything
-            freeze_all(model)
-
+        pass
+
     optimizer = tf.keras.optimizers.Adam(lr=FLAGS.learning_rate)
     loss = [YoloLoss(anchors[mask], classes=FLAGS.num_classes)
             for mask in anchor_masks]
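With transfer and freeze decoupled, the old transfer modes map onto flag combinations: 'darknet' becomes --transfer yolo_darknet --freeze yolo_darknet, 'fine_tune' becomes --transfer all --freeze yolo_darknet, and 'frozen' becomes --transfer all --freeze all. A minimal invocation sketch (paths and class counts are placeholders; the other train.py flags keep their defaults shown above):

    # transfer only the darknet backbone from a pretrained checkpoint, keep it frozen,
    # and tune the conv and output sub-models on a smaller custom class set
    python train.py \
        --weights ./checkpoints/yolov3.tf \
        --weights_num_classes 80 \
        --num_classes 10 \
        --transfer yolo_darknet \
        --freeze yolo_darknet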