如何正確的理解RPN網路的train和test

2021-08-21 23:56:02 字數 4296 閱讀 2420

剛開始學faster rcnn時,遇到些困惑不知其他人有沒有:

1、rpn網路訓練的輸出是什麼?

2、rpn網路在train中的作用是什麼?

3、rpn網路在test中的作用是什麼?

其實這些我們如果不看原始碼都很難真正理解!

以faster-rcnn_tf的原始碼為例,以下**取自./lib/networks/vggnet_train.py

#*****==== rpn **********==

#以下**的先後順序我調整了一下,便於理解

(self.feed('conv5_3')

.conv(3,3,512,1,1,name='rpn_conv/3x3')

.conv(1,1,len(anchor_scales)*3*2 ,1 , 1, padding='valid', relu = false, name='rpn_cls_score'))

(self.feed('rpn_conv/3x3')

.conv(1,1,len(anchor_scales)*3*4, 1, 1, padding='valid', relu = false, name='rpn_bbox_pred'))

.anchor_target_layer(_feat_stride, anchor_scales, name = 'rpn-data' ))

重點

anchor_target_layer的返回值』rpn-data』,這是乙個字典

key分別是:rpn_labels, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights

rpn_labels

是 [1,1,a*height,width],如果把它reshape成[1,a,height,width]會更好理解,即feature map上每一點

都是乙個anchor,每個anchor對應a個bbox,如果乙個bbox與gt_box的重疊度大於0.7(其實還有乙個條件),就認為這個bbox包含乙個前景,則

rpn_labels 矩陣中相應位置就設定為1。

gt_box的label不能直接用來做訓練的目標(target),在訓練中使用rpn_labels作為訓練的目標

gt_box的唯一作用就在於判斷產生的共a*w*h個bbox哪些屬於前景,哪些不屬於,將那些屬於前景的bbox設定為訓練目標去訓練rpn_cls_score_reshape。

在test中,正好相反,訓練好的網路會產生乙個rpn_cls_score_reshape,它可以轉化成乙個[1,a,height,width]的矩陣

#proposal_layer 產生的[1,a,height,width]個bbox哪些屬於前景,哪些屬於背景。我們會把屬於前景的挑出來,

按照得分排序,取前300個輸入後面的fc層,fc層會產生兩個輸出:

乙個是cls_score,用於判斷bbox中物體的型別

另乙個是bbox_pred,用於微調bbox,使其向gt_box進一步靠近(由於bbox都是從anchor產生的,他們不會和gt_box重合,還需要進一步微調)

rpn_bbox_targets

根據 rpn_labels 我們已經可以挑選出300個bbox,這些bbox都是在[1,w,h,a*4]中根據與gt_box的重合程度挑選出來的,與gt_box並不相同,有一些偏差,這些偏差表示為[dx,dy,dw,dh],這就是rpn_bbox_targets。

因為傳進後面全卷積網路的是bbox,與gt_boxes不完全重合,為了使最終的結果更加接近gt_box,還需要進一步微調

而全卷積層的輸出bbox_pred就是用於微調的,rpn_bbox_targets就是它訓練的目標(target)

損失函式的計算:

# rpn

# classification loss

rpn_cls_score = tf.reshape(self.net.get_output('rpn_cls_score_reshape'),[-1,2])

rpn_label = tf.reshape(self.net.get_output('rpn-data')[0],[-1])

rpn_cls_score = tf.reshape(tf.gather(rpn_cls_score,tf.where(tf.not_equal(rpn_label,-1))),[-1,2])

rpn_label = tf.reshape(tf.gather(rpn_label,tf.where(tf.not_equal(rpn_label,-1))),[-1])

rpn_cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=rpn_cls_score, labels=rpn_label))

# bounding box regression l1 loss

rpn_bbox_pred = self.net.get_output('rpn_bbox_pred')

rpn_bbox_targets = tf.transpose(self.net.get_output('rpn-data')[1],[0,2,3,1])

rpn_bbox_inside_weights = tf.transpose(self.net.get_output('rpn-data')[2],[0,2,3,1])

rpn_bbox_outside_weights = tf.transpose(self.net.get_output('rpn-data')[3],[0,2,3,1])

rpn_smooth_l1 = self._modified_smooth_l1(3.0, rpn_bbox_pred, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights)

rpn_loss_box = tf.reduce_mean(tf.reduce_sum(rpn_smooth_l1, reduction_indices=[1, 2, 3]))

其餘**:

# loss of rpn_cls & rpn_boxes

(self.feed('rpn_conv/3x3')

.conv(1,1,len(anchor_scales)*3*4, 1, 1, padding='valid', relu = false, name='rpn_bbox_pred'))

#*****==== roi proposal **********==

(self.feed('rpn_cls_score')

.reshape_layer(2,name = 'rpn_cls_score_reshape')

.softmax(name='rpn_cls_prob'))

(self.feed('rpn_cls_prob')

.reshape_layer(len(anchor_scales)*3*2,name = 'rpn_cls_prob_reshape'))

(self.feed('rpn_cls_prob_reshape','rpn_bbox_pred','im_info')

.proposal_layer(_feat_stride, anchor_scales, 'train',name = 'rpn_rois'))

(self.feed('rpn_rois','gt_boxes')

.proposal_target_layer(n_classes,name = 'roi-data'))

#*****==== rcnn **********==

(self.feed('conv5_3', 'roi-data')

.roi_pool(7, 7, 1.0/16, name='pool_5')

.fc(4096, name='fc6')

.dropout(0.5, name='drop6')

.fc(4096, name='fc7')

.dropout(0.5, name='drop7')

.fc(n_classes, relu=false, name='cls_score')

.softmax(name='cls_prob'))

(self.feed('drop7')

.fc(n_classes*4, relu=false, name='bbox_pred'))

RPN網路的錨是如何生成的

anchor這個問題,我最初也沒弄懂。剛剛看完rbg大神的原始碼,終於明白了,來回答一發,如果有不對的地方請大家指出。以vgg 16改造的faster r cnn為例。py faster r cnn的 model pascal voc vgg16 faster rcnn alt opt faster...

RPN網路的anchor機制

anchors是一組大小固定的參考視窗 三種尺度 三種長寬比,如下圖所示,表示rpn網路中對特徵圖滑窗時每個滑窗位置所對應的原圖區域中9種可能的大小,相當於模板,對任意影象任意滑窗位置都是這9中模板。繼而根據影象大小計算滑窗中心點對應原圖區域的中心點,通過中心點和size就可以得到滑窗位置和原圖位置...

如何正確理解C語言的檔案

c語言的檔案是 c語言的基礎知識,掌握c語言檔案需要了解哪些東西呢?這裡已經給大家詳細列出了知識點 檔案c語言中,把檔案看做乙個字元的序列,也稱字元流 沒有格式 可以簡單認為是分為 文字檔案 以某種編碼儲存顯示的字元 二進位制檔案 以補碼格式儲存 其實是按資料的組織形式來分的 文字檔案 ascii檔...