蒸餾法訓練網路

2021-09-13 00:18:48 字數 1447 閱讀 2092

在ml領域中有一種最為簡單的提公升模型效果的方式,在同一訓練集上訓練多個不同的模型,在**階段採用綜合均值作為**值。但是,運用這樣的組合模型需要太多的計算資源,特別是當單個模型都非常大的時候。已經有相關的研究表明,複雜模型或者組合模型的中「知識」通過合適的方式是可以遷移到乙個相對簡單模型之中,進而方便模型推廣部署。

在大規模的機器學習領域,如物體檢測、語音識別等為了獲得較好的performance常常會訓練很複雜的模型,因為不需要考慮實時性、計算量等因素。但是,在部署階段就需要考慮模型的大小、計算複雜度、速度等諸多因素,因此我們需要更小更精煉的模型用於部署。這種訓練和部署階段不同的模型形態,可以模擬於自然界中很多昆蟲有多種形態以適應不同階段的需求。具體地,如蝴蝶在幼蟲以蛹的形式儲存能量和營養來更好的發育,但是到了後期就為了更好的繁殖和移動它就呈現了另外一種完全不一樣的形態。

有一種直觀的概念就是,越是複雜的網路具有越好的描述能力,可以用來解決更為複雜的問題。我們所說的模型學習得到「知識」就是模型引數,說到底我們想要學習的是乙個輸入向量到輸出向量的對映,而不必太過於去關心中間對映過程。

模型蒸餾就是將訓練好的複雜模型的推廣能力「知識」遷移到乙個結構更為簡單的網路中。或者通過簡單的網路去學習複雜模型中「知識」。其基本流程如下圖:

基本可以分為兩個階段:

原始模型訓練:

a1. 根據提出的目標問題,設計乙個或多個複雜網路(n1,n2,…,nt)。

a2. 使用足夠的訓練資料,按照常規cnn模型訓練流程,並行訓練多個複雜網路,得到(m1,m2,…,mt)

精簡模型訓練:

b1.      根據(n1,n2,…,nt)設計乙個簡單網路n0。

b2.      使用簡單模型訓練資料,此處的訓練資料可以是訓練原始網路的有標籤資料,也可以是額外的無標籤資料。

b3.      將a2中收集到的樣本輸入原始模型(m1,m2,…,mt),修改原始模型softmax層中溫度引數t為乙個較大值,如t=20。每乙個樣本在每個原始模型可以得到其最終的分類概率向量,選取其中概率至最大即為該模型對於當前樣本的判定結果。對於t個原始模型就可以得到t個概率向量。然後對t概率向量求取均值作為當前樣本最後的概率輸出向量,記為soft_target,儲存。

b4.  標籤融合b2中收集到的資料定義為hard_target,有標籤資料的hard_target取值為其標籤值1,無標籤資料hard_taret取值為0。target =a*hard_target + b*soft_target(a+b=1)。target最終作為訓練資料的標籤去訓練精簡模型。引數a,b是用於控制標籤融合權重的,推薦經驗值為(a=0.1 b=0.9)

5. 設定精簡模型softmax層溫度引數與原始複雜模型產生soft-target時所採用的溫度,按照常規模型訓練精簡網路模型。

6. 部署時將精簡模型中的softmax溫度引數重置為1,即採用最原始的softmax

知識蒸餾,緊湊的網路結構簡單記錄

知識蒸餾 遷移學習的一種,目的是將龐大網路學到的知識轉移到小的網路模型上,即不改變網路複雜度的情況下,通過增加監督資訊的豐富程度來提公升效能。關鍵點 1.知識獲取 2.知識轉移 常見集中思想 1.softmax層的輸入比類別標籤包含更多的監督資訊,使用logistics代替類別標籤對小模型進行訓練,...

批量訓練網路

如果整個資料庫中的數量不是每批資料數量的整數倍,體統會將剩餘的放入最後一批 import torch import torch.utils.data as data torch.manual seed 1 reproducible batch size 5x torch.linspace 1,10,...

演算法訓練 JAM計數法

演算法訓練 jam計數法 時間限制 1.0s 記憶體限制 256.0mb 提交此題 錦囊1 錦囊2 問題描述 jam是個喜歡標新立異的科學怪人。他不使用阿拉伯數字計數,而是使用小寫英文本母計數,他覺得這樣做,會使世界更加豐富多彩。在他的計數法中,每個數字的位數都是相同的 使用相同個數的字母 英文本母...