caffe多工學習之多標籤分類

2021-07-28 23:51:34 字數 1792 閱讀 3505

最近在參加乙個識別的競賽,專案裡涉及了許多類別的分類,原本打算乙個大的類別訓練乙個分類模型,但是這樣會比較麻煩,對於同一的分類會重複計算分類網路中的卷積層,浪費計算時間和效率。後來發現現在深度學習中的多工學習可以實現多標籤分類,所有的類別只需要訓練乙個分類模型就行,其不同屬性的類別之間是共享卷積層的。我所有的專案開發都是基於caffe框架的,預設的,caffe中的data層只支援單維標籤,不支援多標籤分類。我也是參考了大牛的部落格修改了caffe裡面的原始碼,使得caffe支援多標籤分類。下面介紹怎麼在caffe中修改原始碼支援多標籤,包括訓練和測試過程的修改。 

caffe原始碼修改:

std::ifstream infile(argv[2]);

std::vector

> > lines;

std::string filename;

std::string label_count_string = argv[5];

int label_count = std::atoi(label_count_string.c_str());

std::vector

label(label_count);

while (infile >> filename)

// create new db

scoped_ptrdb_image(db::getdb(flags_backend));

scoped_ptrdb_label(db::getdb(flags_backend));

db_image->open(argv[3], db::new);

db_label->open(argv[4], db::new);

scoped_ptrtxn_image(db_image->newtransaction());

scoped_ptrtxn_label(db_label->newtransaction());

訓練模型:

上面我們就有了多工的深度學習的基礎部分資料輸入。為了向上相容caffe框架,我也是參考了大牛的部落格,摒棄了部分開源實現增加data層標籤維度選項並修改data層**的做法,直接使用兩個data層將資料讀入,即分別讀入資料和多維標籤。接下來詳細介紹訓練所需要做的步驟以及和修改。 

1. lmdb的資料製作

由於篇幅的原因,我只貼了部分主要的**圖,注意下圖示紅的部分,第乙個是你多標籤所需要的類別數目,第二個是一些資料的路徑。 

由於現在為了支援多標籤,把資料和標籤分開了,以前的單標籤在data層資料和標籤在一起的對應的(自己的理解)。所以第三和第四初標紅的是test和train最後用於訓練的lmdb資料和對應多維標籤。這製作lmdb指令碼檔案我會放在我的部落格資源上: 

#訓練資料層

name: "caffenet"

layer

transform_param

data_param

}#訓練資料標籤層

layer

data_param

}#測試資料層

layer

transform_param

data_param

}#測試資料標籤層

layer

data_param

}

修改完網路模型的data層,後面需要將標籤資料庫中的內容進行切分,拆分成各個屬性的標籤,需要新增slice層,slice層是將乙個輸入層根據切割指標給定的維度(現在只有num和channel)切割成多個輸出層,如下圖所示。有幾類標籤就定義幾類top並命名不同,用於連線最後的accuracy層。對於slice層的引數:

頂 3 踩

caffe多工學習之多標籤分類

最近在參加乙個識別的競賽,專案裡涉及了許多類別的分類,原本打算乙個大的類別訓練乙個分類模型,但是這樣會比較麻煩,對於同一的分類會重複計算分類網路中的卷積層,浪費計算時間和效率。後來發現現在深度學習中的多工學習可以實現多標籤分類,所有的類別只需要訓練乙個分類模型就行,其不同屬性的類別之間是共享卷積層的...

caffe多工學習之多標籤分類

最近在參加乙個識別的競賽,專案裡涉及了許多類別的分類,原本打算乙個大的類別訓練乙個分類模型,但是這樣會比較麻煩,對於同一的分類會重複計算分類網路中的卷積層,浪費計算時間和效率。後來發現現在深度學習中的多工學習可以實現多標籤分類,所有的類別只需要訓練乙個分類模型就行,其不同屬性的類別之間是共享卷積層的...

多工學習

最近一段時間multitask網路比較流行,比如做人臉檢測的時候,乙個網路完成 人臉和非人臉 二分類任務的同時也要進行boudingbox回歸或者人臉關鍵點回歸。以人臉檢測mtcnn為例,乙個網路包含三個任務。訓練的時候,乙個batch中的,一部分用於二分類 一部分用於boundingbox 回歸,...