Skip to content

SpikeKing/triplet-loss-mnist

Repository files navigation

Triplet Loss 损失函数

Triplet Loss是深度学习中的一种损失函数,用于训练差异性较小的样本,如人脸等, Feed数据包括锚(Anchor)示例、正(Positive)示例、负(Negative)示例,通过优化锚示例与正示例的距离小于锚示例与负示例的距离,实现样本的相似性计算。

数据集:MNIST

框架:DL-Project-Template

目标:通过Triplet Loss训练模型,实现手写图像的相似性计算。

工程https://github.com/SpikeKing/triplet-loss-mnist


模型

Triplet Loss的核心是锚示例、正示例、负示例共享模型,通过模型,将锚示例与正示例聚类,远离负示例。

Triplet Loss Model的结构如下:

  • 输入:三个输入,即锚示例、正示例、负示例,不同示例的结构相同;
  • 模型:一个共享模型,支持替换为任意网络结构;
  • 输出:一个输出,即三个模型输出的拼接。

Shared Model选择常用的卷积模型,输出为全连接的128维数据:

Triplet Loss 损失函数的计算公式如下:


训练

模型参数:

  • batch_size:32
  • epochs:2

超参数:

  • 边界Margin的值设置为1

训练命令:

python main_train.py -c configs/triplet_config.json

训练日志:

Using TensorFlow backend.
[INFO] 解析配置...
[INFO] 加载数据...
[INFO] X_train.shape: (60000, 28, 28, 1), y_train.shape: (60000, 10)
[INFO] X_test.shape: (10000, 28, 28, 1), y_test.shape: (10000, 10)
[INFO] 构造网络...
[INFO] model - 锚shape: (?, 128)
[INFO] model - 正shape: (?, 128)
[INFO] model - 负shape: (?, 128)
[INFO] model - triplet_loss shape: (?, 1)
[INFO] 训练网络...
[INFO] trainer - 类别数: 10
Train on 54200 samples, validate on 8910 samples
2018-04-24 10:59:06.130952: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
Epoch 1/2
54200/54200 [==============================] - 84s 2ms/step - loss: 0.6083 - val_loss: 0.0515
Epoch 2/2
54200/54200 [==============================] - 83s 2ms/step - loss: 0.0869 - val_loss: 0.0314

算法收敛较好,Loss线性下降:

TF Graph:


验证

算法效率(TPS): 每秒48163次 (0.0207625 ms/t)

测试命令:

python main_test.py -c configs/triplet_config.json

测试日志:

Using TensorFlow backend.
[INFO] 解析配置...
[INFO] 加载数据...
[INFO] 预测数据...
展示数据: (10000, 784)
展示标签: (10000,)
日志目录: /Users/wang/workspace/triplet-loss-mnist/experiments/triplet_mnist/logs/default
2018-04-24 11:02:07.874682: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
[INFO] model - triplet_loss shape: (?, 1)
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
anc_input (InputLayer)          (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
pos_input (InputLayer)          (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
neg_input (InputLayer)          (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
model_1 (Model)                 (None, 128)          112096      anc_input[0][0]                  
                                                                 pos_input[0][0]                  
                                                                 neg_input[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 384)          0           model_1[1][0]                    
                                                                 model_1[2][0]                    
                                                                 model_1[3][0]                    
==================================================================================================
Total params: 112,096
Trainable params: 112,096
Non-trainable params: 0
__________________________________________________________________________________________________
TPS: 272553.829381 (0.003669 ms)
验证结果结构: (10000, 128)
展示数据: (10000, 128)
展示标签: (10000,)
日志目录: /Users/wang/workspace/triplet-loss-mnist/experiments/triplet_mnist/logs/test
[INFO] 预测完成...

MNIST验证集的效果:

[INFO] trainer - clz 0
[INFO] trainer - distance - min: -15.4567, max: 1.98611, avg: -6.50481
[INFO] acc: 0.996632996633

[INFO] trainer - clz 1
[INFO] trainer - distance - min: -13.09, max: 3.43779, avg: -6.66867
[INFO] acc: 0.99214365881

[INFO] trainer - clz 2
[INFO] trainer - distance - min: -14.2524, max: 2.49437, avg: -5.60508
[INFO] acc: 0.991021324355

[INFO] trainer - clz 3
[INFO] trainer - distance - min: -16.6555, max: 1.21776, avg: -6.32161
[INFO] acc: 0.995510662177

[INFO] trainer - clz 4
[INFO] trainer - distance - min: -14.193, max: 1.65427, avg: -5.90896
[INFO] acc: 0.991021324355

[INFO] trainer - clz 5
[INFO] trainer - distance - min: -14.1007, max: 2.01843, avg: -6.36086
[INFO] acc: 0.994388327722

[INFO] trainer - clz 6
[INFO] trainer - distance - min: -16.8953, max: 2.84421, avg: -8.43978
[INFO] acc: 0.995510662177

[INFO] trainer - clz 7
[INFO] trainer - distance - min: -16.6177, max: 3.49675, avg: -5.99822
[INFO] acc: 0.989898989899

[INFO] trainer - clz 8
[INFO] trainer - distance - min: -14.937, max: 3.38141, avg: -5.4424
[INFO] acc: 0.979797979798

[INFO] trainer - clz 9
[INFO] trainer - distance - min: -16.9519, max: 2.39112, avg: -5.93581
[INFO] acc: 0.985409652076

测试的MNIST分布:

输出的Triplet Loss MNIST分布:

本例仅仅使用2个Epoch,也没有特殊设置超参,实际效果仍有提升空间。


By C. L. Wang

About

Triplet Loss 损失函数

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages