再也不用擔心過擬合的問題了

2021-10-22 12:27:43 字數 3736 閱讀 7116

編譯:ronghuaiyang

使用sam(銳度感知最小化),優化到損失的最平坦的最小值的地方,增強泛化能力。

動機來自先前的工作,在此基礎上,我們提出了一種新的、有效的方法來同時減小損失值和損失的銳度。具體來說,在我們的處理過程中,進行銳度感知最小化(sam),在領域內尋找具有均勻的低損失值的引數。這個公式產生了乙個最小-最大優化問題,在這個問題上梯度下降可以有效地執行。我們提出的實證結果表明,sam在各種基準資料集上都改善了的模型泛化。

在深度學習中,我們使用sgd/adam等優化演算法在我們的模型中實現收斂,從而找到全域性最小值,即訓練資料集中損失較低的點。但等幾種研究表明,許多網路可以很容易地記住訓練資料並有能力隨時overfit,為了防止這個問題,增強泛化能力,谷歌研究人員發表了一篇新**叫做sharpness awareness minimization,在cifar10上以及其他的資料集上達到了最先進的結果。

在本文中,我們將看看為什麼sam可以實現更好的泛化,以及我們如何在pytorch中實現sam。

在梯度下降或任何其他優化演算法中,我們的目標是找到乙個具有低損失值的引數。但是,與其他常規的優化方法相比,sam實現了更好的泛化,它將重點放在領域內尋找具有均勻的低損失值的引數(而不是只有引數本身具有低損失值)上。

由於計算鄰域引數而不是計算單個引數,損失超平面比其他優化方法更平坦,這反過來增強了模型的泛化。

(左))用sgd訓練的resnet收斂到的乙個尖銳的最小值。(右)用sam訓練的相同的resnet收斂到的乙個平坦的最小值。

注意:sam不是乙個新的優化器,它與其他常見的優化器一起使用,比如sgd/adam。

在pytorch中實現sam非常簡單和直接

import torch

class sam(torch.optim.optimizer):

def __init__(self, params, base_optimizer, rho=0.05, **kwargs):

assert rho >= 0.0, f"invalid rho, should be non-negative: "

defaults = dict(rho=rho, **kwargs)

super(sam, self).__init__(params, defaults)

self.base_optimizer = base_optimizer(self.param_groups, **kwargs)

self.param_groups = self.base_optimizer.param_groups

@torch.no_grad()

def first_step(self, zero_grad=false):

grad_norm = self._grad_norm()

for group in self.param_groups:

scale = group["rho"] / (grad_norm + 1e-12)

for p in group["params"]:

if p.grad is none: continue

e_w = p.grad * scale.to(p)

p.add_(e_w)  # climb to the local maximum "w + e(w)"

self.state[p]["e_w"] = e_w

if zero_grad: self.zero_grad()

@torch.no_grad()

def second_step(self, zero_grad=false):

for group in self.param_groups:

for p in group["params"]:

if p.grad is none: continue

p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

self.base_optimizer.step()  # do the actual "sharpness-aware" update

if zero_grad: self.zero_grad()

def _grad_norm(self):

shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism

norm = torch.norm(

torch.stack([

p.grad.norm(p=2).to(shared_device)

for group in self.param_groups for p in group["params"]

if p.grad is not none

]),p=2

)return norm

**取自非官方的pytorch實現。

**解釋:

...雖然sam的泛化效果較好,但是這種方法的主要缺點是,由於前後兩次計算銳度感知梯度,需要花費兩倍的訓練時間。除此之外,sam還在最近發布的nfnets上證明了它的效果,這是imagenet目前的最高水平,在未來,我們可以期待越來越多的**利用這一技術來實現更好的泛化。

—end—

英文原文:

喜歡的話,請給我個在看吧

css選擇器 看完再也不用擔心選擇器問題了

1.萬用字元選擇器 萬用字元選擇器用來匹配到所有的元素 一般用法 用來清除元素預設的內外邊距,使得樣式統一 2.標籤選擇器 標籤選擇器就是用來匹配到相對應的html標籤 比如 選中所有標籤為p的元素,把其文字的樣式修改為粉色pink p3.類選擇器 選中類名為指定類名的所有元素 比如 選中類名為bo...

Flex,再也不用擔心頁面布局了

布局的傳統解決方案,基於盒狀模型,依賴display屬性 position屬性 float屬性。它對於那些特殊布局非常不方便,比如,垂直居中就不容易實現。flex 是 flexible box 的縮寫,意為 彈性布局 用來為盒狀模型提供最大的靈活性。任何乙個容器都可以指定為 flex 布局。box ...

媽媽再也不用擔心我使用git了

git由於其靈活,速度快,離線工作等特點而倍受青睞,下面一步步來總結下git的基本命令和常用操作。你可以在本地操作git,也可以在遠端伺服器倉庫操作git,例如github,這樣你就需要配置下ssh key,詳情請檢視官方文件說明generating ssh keys 本地轉殖 git clone ...