-
Notifications
You must be signed in to change notification settings - Fork 5
SCA CNN阅读以及Tensorflow使用中碰到的一些问题
思想:主要是对CNN部分改进,由于CNN具有 spatial, channel-wise, and multi-layer三个性质,而传统的attention-base的模型只在spatial上作了attention,而本文在三个方面都作了attention操作
-
整体上使用的还是CNN encoder-RNN decoder的结构
-
处理部分主要是在CNN添加多种attention。
-
亮点在于Channel-wise attention,实现是对CNN网络第i层的输入Vl ,对其每个channel,也就是fliter,作了一个mean_pooling的操作,得到了每个channel的权值,然后通过一个全连接层和softmax得到了权值。再与输入Vl 作外积。(beta便是attention值)
- Spatial attention也是类似,实现方法是:
假如Vl 为heightwidthchannel,通过flatten把其展成(height*width)个cahnnel维的矢量,再通过全连接层得到attention的权值(alpha便是antettion值)
-
文章还通过实现作了对比,发现先做channel-wise attention再做Spatial attention 效果比,先做Spatial attention再做channel-wise attention效果要好。
-
图像特征在输入RNN的时候,用的是作了c-s attention处理的CNN的最后一层feature作为输入
slim是Tensorflow中的一个高层框架,其预定义了许多art-of-state的模型,比如VGG,Inception,Restnet还能从其网站上下载已经训练好的权重,在作模型的fine tune的时候非常方便。
Dataset是tensorflow新的输入api,可以通过get_next()方法直接获得一条数据的张量,因此不需要构建placeholder便能很轻松的把数据传进图中,能通过map方法对数据预处理,还有shuffle,repeate,batch等非常方便方法。
Dataset通常搭配estimator使用,但是estimator要求以固定的格式去编写模型函数,由于考虑到特定的初始化问题,比如刚开始时载入一部分Inception的参数,其他参数随机初始化,查阅了一下文档后发现,tensorflow提供了一个MonitoredTrainingSession,并提供了一个should_stop的方法判断Dataset是否已经处理完,这样就相对灵活一点。
-
由于attention需要修改CNN模型的结构,需要部分地restore模型的参数,可以通过 tf.saver(var_list)来实现。而为了找到var_list必须遍历已有静态图中可以放入参数的Variables,看看名字是否与checkpoint文件中的名字相同。
-
slim的Variables有自己的命名方式,因此要在slim.arg_scope下读取模型,并且得使用slim的get_variables_by_name
-
提取的image_feature是作为输入,而不是初始状态进入RNN的,RNN初状态置0。
-
LSTM的输出是以LSTMStateTuple的形式输出,相当于一个[cell_state,hidden_state]的一个张量。因此,LMST的state_size也是以[cell_size,hidden_size]的形式给出
-
对于encoder-decoder模型,Batch_Loss是以token为单位的,即一个batch有n个句子,而n个句子中一共有m个token,则Batch_Loss则是m个输入和输出的loss和求均值。这在实现的时候可以通过一个mask矩阵来判断某个句子的某个token是否需要计算进loss里。
-
最新的Dataset API中,读入的数据维度要固定的,但是像caption这种不固定长度的数据,一种处理办法就是先把数据转化成TFrecord类型。而这里有个坑就是tfrecord处理feature时对字符串是按byteslist来处理的,对于python3的字符串,要先用str.decode(s)把srt类型转换成bytes后再处理
-
tensor对元素的操作有时候非常麻烦。在生成mask张量时,即非零元素的位置置1,可以用tf.where函数,如下:
pos=tf.not_equal(input,tf.const(0,tf.int64))
mask=tf.where(pos,tf.fill(tf.shape(pos),1),tf.fill(tf.shape(pos),0))