pytorch 載入資料集

2021-10-25 16:10:21 字數 4284 閱讀 3708

2 tensor 的 構造方式

import torch

import numpy as np

data = np.array([1,2,3])

print(torch.tensor(data))

# 副本

print(torch.tensor(data))

# 副本

print(torch.as_tensor(data))

# 檢視

print(torch.from_numpy(data))

# 檢視

4 torch 的 全域性預設值

5 torch.tensor 的四種構造方式各有利弊

6 torch.tensor 的分量

7 張量的壓縮和解壓縮

data = np.array(

[[1,1,1,1],

[2,2,2,2],

[3,3,3,3]])

t = torch.tensor(data)

print(t.reshape(1,12).squeeze(dim=0))

print(t.reshape(1,12).squeeze(dim=0).shape)

print(t.unsqueeze(dim=2))

print(t.unsqueeze(dim=2).shape)

print(t.reshape(-1))

print(t.flatten())

print(torch.cat((t1,t2),dim=0))
t1 = torch.ones((4,4),dtype = int)

t2 = torch.ones((4,4),dtype = int)*2

t3 = torch.ones((4,4),dtype = int)*3

t = torch.stack((t1,t2,t3))

t = t.reshape(3,1,4,4)

print(t)

t = t.flatten(start_dim=1)

print(t)

np.broadcast_to(2,[2,2])
t = torch.tensor(

[[0,5,7],

[6,0,7],

[0,8,0]],dtype=torch.float32)

print(t.eq(0)) # 0 是否相等

print(t.ge(5)) # 5 大於等於5

print(t.gt(5)) # 5 大於5

print(t.lt(7)) # 7 小於7

print(t.le(7)) # 7 小於等於7

print(t.neg()) # 取相反數

print()

-

13 torch.tensor 張量縮減

t = torch.tensor(np.ndarray)

t.max()

t.sum()

t.argmax() # 取最大值處的索引

t.max(dim=0)

t.argmax(dim=0)

t.mean(dim=1).tolist() # list

t.mean(dim=1).numpy() # convert to a ndarray

-

14 etl

- e extract 提取

- t transform 轉換

- l load 載入

tensor([[0.8410, 0.8758, 0.6165, 0.2126],

[0.0720, 0.9284, 0.4022, 0.1788],

[0.1653, 0.8602, 0.7925, 0.3128]])

-

15 torchvision : 是乙個隊流行的資料集、模型和計算機視覺的影象轉換的訪問的包

- torchvision.transforms :乙個介面,能夠訪問影象處理的通用轉換

-

16 dataset 資料集

-

17 dataloader 資料載入器封裝資料集並提供對底層資料的訪問

class ohlc(dataset):

definit(self,csv_file):

self.data = pd.read_csv(csv_file)

def __getitem__(self,index):

r = self.data.iloc[index]

label = torch.tensor(r.is_up_day,dtype = torch.long)

sample = self.normalize(torch.tensor([r.open,r.high,r.low,r.close]))

return sample,label

del __len__(self):

return len(train_set)

-

18 獲取資料集,轉化,放在資料載入器中

import torch

import torchvision

import torchvision.transforms as transforms

train_set = torchvision.datasets.fashionmnist(

root = 『./data/fashionmnist』,

train=true,

download=true,

transform = transforms.conpose(

[transforms.totensor()

]))

file "", line 1

- 18 獲取資料集,轉化,放在資料載入器中

^syntaxerror: invalid character in identifier

import torch

import torchvision

import torchvision.transforms as transforms

import numpy as np

import matplotlib.pyplot as plt

root = './data/fashionmnist', # 資料集存放的位置

transform = transforms.compose( # 將資料集轉化為我們需要的張量的型別

[transforms.totensor()

]))train_loader = torch.utils.data.dataloader(train_set,batch_size=10) # 在這裡定義了乙個資料載入器物件,引數是資料集和一批的數量

torch.set_printoptions(linewidth=120)

print(train_set.train_labels[:100]) # 列印訓練集標籤的前100個

train_set.train_labels.bincount() # 這裡可以看出訓練集中標籤的各個分類的個數 、、 目前來說一般都是均衡分配的

batch = next(iter(train_loader)) # 這裡使用了迭代器,返回train_loader 的 第乙個元素,也就是第一批

images,labels = batch # 每一批都是tuple型別,images和labels都是tensor型別,images的shape是【10,1,28,28】

grid = torchvision.utils.make_grid(images,nrow=10) # make_grid的作用是將多幅拼成一幅

plt.figure(figsize=(15,15))

plt.imshow(np.transpose(grid,(1,2,0))) // 有關np.transpose()函式的作用請參考

print('labels:',labels)

PyTorch學習 載入資料集

需要定義diabetesdataset做為載入資料集diabetes的類,繼承自dataset,dataset是抽象類,需要實現其中的三個方法,init,getitem,len import torch from torch.utils.data import dataset 抽象類 from to...

pytorch 載入自己的資料集

pytorch 載入自己的資料集,需要寫乙個繼承自torch.utils.data中dataset類,並修改其中的 init 方法 getitem 方法 len 方法。預設載入的都是,init 的目的是得到乙個包含資料和標籤的list,每個元素能找到位置和其對應標籤。然後用 getitem 方法得到...

pytorch載入資料

參考 pytorch深度學習快速入門教程 絕對通俗易懂!小土堆 可看到說明,dataset是乙個抽象類,我們重寫dataset時要繼承這個類,所有的子類都應該重寫 getitem 方法,這個方法作用是獲取資料及對應的labe。同時我們可以選擇性地去重寫 len 方法,其作用是獲取資料集長度。這裡我使...