Pytorch實現模型蒸餾

2021-10-25 22:24:33 字數 2541 閱讀 8201

簡單記錄一下使用pytorch進行模型蒸餾的主要**,,其餘資料處理的內容可以另行補充

import torch

import torch.nn as nn

import numpy as np

from torch.nn import crossentropyloss

from torch.utils.data import tensordataset,dataloader,sequentialsampler

class

model

(nn.module)

: def __init__

(self,input_dim,hidden_dim,output_dim)

:super

(model,self)

.__init__()

self.layer1 = nn.

lstm

(input_dim,hidden_dim,output_dim,batch_first = true)

self.layer2 = nn.

linear

(hidden_dim,output_dim)

def forward

(self,inputs)

: layer1_output,layer1_hidden = self.

layer1

(inputs)

layer2_output = self.

layer2

(layer1_output)

layer2_output = layer2_output[:,

-1,:

]#取出乙個batch中每個句子最後乙個單詞的輸出向量即該句子的語義向量!!!!!!!!!

return layer2_output

#建立小模型

model_student =

model

(input_dim =

2,hidden_dim =

8,output_dim =4)

#建立大模型(此處仍然使用lstm代替,可以使用訓練好的bert等複雜模型)

model_teacher =

model

(input_dim =

2,hidden_dim =

16,output_dim =4)

#設定輸入資料,此處只使用隨機生成的資料代替

inputs = torch.

randn(4

,6,2

)true_label = torch.

tensor([

0,1,

0,0]

)#生成dataset

dataset =

tensordataset

(inputs,true_label)

#生成dataloader

sampler =

sequentialsampler

(inputs)

dataloader =

dataloader

(dataset = dataset,sampler = sampler,batch_size =2)

loss_fun =

crossentropyloss()

criterion = nn.

kldivloss

()#kl散度

optimizer = torch.optim.

sgd(model_student.

parameters()

,lr =

0.1,momentum =

0.9)#優化器,優化器中只傳入了學生模型的引數,因此此處只對學生模型進行引數更新,正好實現了教師模型引數不更新的目的

for step,batch in

enumerate

(dataloader)

: inputs = batch[0]

labels = batch[1]

#分別使用學生模型和教師模型對輸入資料進行計算

output_student =

model_student

(inputs)

output_teacher =

model_teacher

(inputs)

#計算學生模型和真實標籤之間的交叉熵損失函式值

loss_hard =

loss_fun

(output_student,labels)

#計算學生模型**結果和教師模型**結果之間的kl散度

loss_soft =

criterion

(output_student,output_teacher)

loss =

0.9*loss_soft +

0.1*loss_hard

print

(loss)

optimizer.

zero_grad()

loss.

backward()

optimizer.

step

()

pytorch設計模型

1.nn.modulelist使對於加入其中的子模組,不必在forward中依次呼叫 nn.sequentialt使對於加入其中的子模組在forward中可以通過迴圈實現呼叫 2.pytorch中nn.modulelist和nn.sequential的用法和區別 nn.sequential定義的網路...

pytorch自動求導實現感知機模型

感知機理論知識鏈結。usr bin env python encoding utf 8 created on 2019 12 31 21 17 author phil import torch.nn as nn import torch import numpy as np import torch...

pytorch模型儲存的2種實現方法

1 儲存整個網路結構資訊和模型引數資訊 torch.s e model object,vghvryahy model.pth 直接載入即可使用 model tor程式設計客棧ch.load model.pth 2 只儲存網路的模型引數 推薦使用 www.cppcns.comtorch.s e mod...