前兩天學習了huggingface datasets來寫乙個資料載入指令碼,但是,在實驗中發現,使用dataloader載入資料的便捷性,這兩天查資料勉強重寫dataloader載入本地資料,在這裡記錄下,如果有錯誤,歡迎指正。 總結
在pytorch官網搜尋dataloader,返回的一篇教程是
writing custom datasets, dataloaders and transforms講了編寫自定義資料集,資料載入器和轉換,但是講的是影象資料,在這裡我還使用蘇劍林老師在《基於cnn的閱讀理解式問答模型:dgcnn》裡提供的webqa和sogouqa資料集來重寫dataloader,因為蘇劍林老師是用bert4keras庫 來寫資料集載入,我看的是task_reading_comprehension_by_mlm.py的資料載入格式,有部分**我沒有看懂,我按照自己的理解來重寫的。
torch.utils.data.dataset是表示資料集的抽象類。自定義資料集應繼承資料集並覆蓋以下方法:
對自己的資料集首先要建立乙個dataset 類。在__len__
中讀取檔案,要將讀取的檔案載入到__getitem__
中,這樣的話所有資料不會立即儲存到記憶體中,二十根據需要讀取,因此可以提高記憶體效率。視覺化演示可以參考這裡。
torch.utils.data.dataset 是乙個抽象類,它只能被繼承。在b站上講解的有兩個我參考的比較好的教程,飯客帆和劉二大人的pytorch教程都很好。
首先是導包。
import re
import torch
import tokenizers
from torch.utils.data import dataset, dataloader
import numpy as np
這裡有幾個要預定義的超引數:
max_length =
384max_p_len =
256max_q_len =
64max_a_len =
32
在蘇神的示例中
輸入:[cls][mask][mask][sep]問題[sep]篇章[sep]
輸出:答案
先要說明的是,蘇神的這個資料載入的**我沒有完全看懂,全部的**我暫時還沒看完,我先按照自己的理解來寫自己的,等我看完所有的**會回來重新修改這篇文章。蘇神的原始碼等我看懂了會回來修改的。
class
mydataset
(dataset)
:# 繼承dataset模組的dataset類
# 初始化定義,得到資料內容
def__init__
(self, data_set)
:super
(mydataset, self)
.__init__(
) self.data_set = data_set # 載入資料集
self.length =
len(data_set)
# 資料集長度
# 返回資料集大小
def__len__
(self)
:return self.length
# 資料預處理,這部分根據自己的資料集進行處理
def__getitem__
(self, index)
:# index(或item)不能少,這個引數是來挑選某條資料的
# d = self.data_set[index]
# 從data_set中取樣乙個資料
token_ids, segment_ids, a_token_ids =,,
question = d[
'question'
] answers =
[p['answer'
]for p in d[
'passages'
]if p[
'answer']]
# 蘇神的**,我沒有看懂是否挑選了無答案的,這裡是自己改的。挑選帶答案的文章
passage =
""# 先宣告再使用是個好習慣,不然會報錯
for pre_passage in d[
'passages']:
if pre_passage[
'answer']:
passage = pre_passage[
'passage'
]break
passage = re.sub(u' |、|;|,'
,','
, passage)
# 清洗資料
final_answer =
''for answer in answers:
# 選擇答案
ifall
([a in passage[
:max_p_len -2]
for a in answer.split(
' ')])
: final_answer = answer.replace(
' ',
',')
break
# print(question)
a_token_ids = tokenizer.encode(final_answer, max_length=max_a_len +
1, padding=
"max_length"
, truncation=
true
)# 答案編碼
q_token_ids = tokenizer.encode(question, max_length=max_q_len +
1, truncation=
true
)# 對問題進行截斷
p_token_ids = tokenizer.encode(passage, max_length=max_p_len +
1, truncation=
true
) token_ids +=
[tokenizer.mask_token_id]
* max_a_len
token_ids +=
[tokenizer.sep_token_id]
token_ids +=
(q_token_ids[1:
]+ p_token_ids[1:
-1])
# [mask][mask][sep]問題[sep]篇章
token_ids = tokenizer.encode(tokenizer.convert_ids_to_tokens(token_ids)
, max_length=max_length ,
padding=
"max_length"
, truncation=
true
)# [cls][mask][mask][sep]問題[sep]篇章[sep]
segment_ids =[0
]*len(token_ids)
token_ids = torch.as_tensor(token_ids)
segment_ids = torch.as_tensor(segment_ids)
a_token_ids = torch.as_tensor(a_token_ids)
return
[token_ids, segment_ids]
, a_token_ids
**如下(示例):
首先,例項化
data = mydataset(train_data)
輸出一下結果
這裡自己重寫了dataloader,有需要學習dataloader載入本地資料的,可以仿照寫就可以了。
PyTorch學習 安裝PyTorch
例如,使用的是 windows 系統,想用 pip 安裝,python 是 3.6 版的,沒有 gpu 加速,那就按上面的選,然後根據上面的提示,在 terminal 中輸入以下指令就好了 pip3 install torch 1.3.1 cpu torchvision 0.4.2 cpu ftor...
Pytorch 通過pytorch實現線性回歸
linear regression 線性回歸是分析乙個變數與另外乙個 多個 變數之間關係的方法 因變數 y 自變數 x 關係 線性 y wx b 分析 求解w,b 求解步驟 1.確定模型 2.選擇損失函式 3.求解梯度並更新w,b 此題 1.model y wx b 下為 實現 import tor...
PyTorch入門(三)PyTorch常用操作
def bilinear kernel in channels,out channels,kernel size return a bilinear kernel tensor tensor in channels,out channels,kernel size,kernel size 返回雙線性...