構建機器學習演算法和剪枝

2021-08-22 02:45:48 字數 1795 閱讀 3035

幾乎所有的深度學習演算法都可以被描述為乙個相當簡單的配方:特定的資料集、代價函式、優化過程和模型。

在大多數情況下,優化演算法可以定義為求解代價函式梯度為零的正規方程。

通常代價函式至少含有一項使學習過程進行統計估計的成分。最常見的代價函式是負對數似然,最小化代價函式導致的最大似然估計。

組合模型、代價和優化演算法來構建學習演算法的配方同時適用於監督學習和無監督學習。在某些情況下,由於計算原因,我們不能實際計算代價函式。在這種情況下,只要我們有近似其梯度的方法,那麼我們仍然可以使用迭代數值優化近似最小化目標。

隨著2023年alextnet的橫空出世,其遠超傳統方法的分類準確率讓深度學習重新走進人們的視野。為提高網路的效能,神經網路逐漸朝著越來越深和越來越寬的方向發展,引數數量隨之激增。例如vgg對alexnet進行改進,在簡單二分類問題上,實測網路模型大小達到了150mb。而在資源受限的場景中,例如移動裝置,計算和儲存資源有限,大尺寸網路的模型難以應用。

圖1在神經網路中,並不是所有的引數都有著相同的作用。部分引數幅值大,對網路輸出的影響顯著,某些引數權重接近0,對網路輸出的影響微乎其微。如果在神經網路中,找到並刪除這些對網路影響較小的引數,那麼可以在保持網路效能的前提下,減小網路模型的尺寸,提高模型的計算速度。如何在找到對網路影響較小的引數,如何刪除網路的部分引數或結構,怎樣在減小網路尺寸的同時,保證網路的效能,這就是剪枝要完成的工作。

剪枝顧名思義,就是剪去網路中不重要的部分,減小網路的尺寸,簡化網路的結構.在實施時,首先通過設定規則,找到要減去的網路結構,在減去該部分結構後,網路效能會受到影響,因此需要重新訓練,恢復網路的效能.

圖2在對圖2左圖所示網路進行剪枝,對網路結構中任一神經元的每乙個連線(接收上一層的輸出),將其與設定的閾值相比較,如果大於該閾值,則保留該鏈結,如果小與該閾值,則丟棄該連線。當乙個神經元大所有連線都被丟棄時,該神經元同樣被丟棄。按照這一規則剪枝後的網路結構如圖2右所示。

剪枝時閾值的由稀疏度決定,而連線的權重與閾值比較的標準有l1 norm,l2 norm等,最常用的是l1 norm,其簡單容易操作。

該方法的實施依賴於mask實現,在tensorflow中有pruning介面,供實現該剪枝方法。但該方法的優點是原理簡單,卷積和全連線層的操作相同,資料壓縮程度大,對精度的影響小。但該方法在實施時步驟繁瑣,需要分層和迭代的剪枝。此外,由於該方法借助於mask實現,因此模型的大小和在模型的速度難以有明顯的提公升。

上節卷積核剪枝中剪枝操作借助mask實現的主要原因是要減去的網路結構是不規則的,如果每次剪去一整個網路單元,或者乙個全連線的神經元,那麼剪枝的操作將被簡化,同時不借助mask實現起來也更為便捷。  

卷積核剪枝的基本思路即是如此,每次剪去一整個卷積核。具體的方法是,對某卷積層而言,其中包含了許多卷積核,如上圖的kernel matrix的列所示。在對某層進行剪枝時,分別計算每個卷積核的l1 norm(卷積核的所有權值的絕對值之和),小於某閾值的卷積核將被丟棄(注意,計算卷積核的l1 norm時應使用剪枝前的卷積核權值)。該閾值的設定方法與1類似,由稀疏度決定。每剪去乙個卷積核,該卷積層輸出的feature map將會減少乙個,為保證輸入feature map和權值維度匹配,應對下一層的輸入權值進行調整,將下一層每個kernel中對該feature map進行卷積操作的channel刪除,如圖3中右邊的kernel matrix所示。

需要注意的是,如果網路中使用到了batch norm,則在剪枝時要對batch norm的引數一併處理。

在實施時,有兩種訓練的方法,一種是從靠近輸入層開始,逐層進行剪枝,每剪完一層,重新訓練後再進行下一層的剪枝。這種方法的剪枝得到的模型精度較高。另一種方法是同時對所有層進行剪枝,每次剪枝時設定較小的稀疏度,重訓後,再設定小稀疏度進行剪枝,如此迭代。這種方法操作起來更簡單。

機器學習 樹的剪枝策略

決策樹剪枝分前剪枝 預剪枝 和後剪枝兩種形式.決策樹為什麼 why 要剪枝?原因是避免決策樹過擬合 overfitting 樣本。前面的演算法生成的決策樹非常詳細並且龐大,每個屬性都被詳細地加以考慮,決策樹的樹葉節點所覆蓋的訓練樣本都是 純 的。因此用這個決策樹來對訓練樣本進行分類的話,你會發現對於...

機器學習系統構建

首先是機器學習系統構建的流程 ng推薦方法 首先高速實現乙個可能並非非常完美的演算法系統。進行交叉驗證,畫出學習曲線去學習演算法問題之處,是high bias or high variance 細節看這篇博文介紹 bias和variance在機器學習中應用 最重要一步 錯誤分析。手工檢驗演算法錯誤學...

機器學習 構建機器學習流水線

from sklearn.datasets import samples generator from sklearn.ensemble import randomforestclassifier from sklearn.feature selection import selectkbest,f...