以前直接用的是sklearn或者tensorflow提供的mnist資料集,已經轉換為矩陣形式的資料格式。但是sklearn體用的資料集合並不全,一共只有3000+圖,每個圖是8*8的大小,但是原始資料並不是這樣的。
mnist資料集合的原始**為:
進入官網,發現有4個檔案,分別對應訓練集、測試集的影象和標籤:
官網給的資料集合並不是原始的影象資料格式,而是編碼後的二進位制格式:
影象的編碼為:
典型的head+data模式:前16個位元組分為4個整型資料,每個4位元組,分別代表:資料資訊des、影象數量(img_num),影象行數(row)、影象列數(col),之後的資料全部為畫素,每row*col個畫素構成一張圖,每個色素的值為(0-255)。
標籤的編碼為:
模式和前面的一樣,不同的是head只有8位元組,分別為des和標籤的數量(label_num).之後每乙個位元組代表乙個標籤,值為(0-9)。
弄清楚編碼後,就可以直接上**了:
import numpy as np
import struct
mnist_dir = r'./digit/'
deffetch_mnist
(mnist_dir,data_type):
train_data_path = mnist_dir + 'train-images.idx3-ubyte'
train_label_path = mnist_dir + 'train-labels.idx1-ubyte'
test_data_path = mnist_dir + 't10k-images.idx3-ubyte'
test_label_path = mnist_dir + 't10k-labels.idx1-ubyte'
# train_img
with open(train_data_path, 'rb') as f:
data = f.read(16)
des,img_nums,row,col = struct.unpack_from('>iiii', data, 0)
train_x = np.zeros((img_nums, row*col))
for index in range(img_nums):
data = f.read(784)
if len(data) == 784:
train_x[index,:] = np.array(struct.unpack_from('>' + 'b' * (row * col), data, 0)).reshape(1,784)
f.close()
# train label
with open(train_label_path, 'rb') as f:
data = f.read(8)
des,label_nums = struct.unpack_from('>ii', data, 0)
train_y = np.zeros((label_nums, 1))
for index in range(label_nums):
data = f.read(1)
train_y[index,:] = np.array(struct.unpack_from('>b', data, 0)).reshape(1,1)
f.close()
# test_img
with open(test_data_path, 'rb') as f:
data = f.read(16)
des, img_nums, row, col = struct.unpack_from('>iiii', data, 0)
test_x = np.zeros((img_nums, row * col))
for index in range(img_nums):
data = f.read(784)
if len(data) == 784:
test_x[index, :] = np.array(struct.unpack_from('>' + 'b' * (row * col), data, 0)).reshape(1, 784)
f.close()
# test label
with open(test_label_path, 'rb') as f:
data = f.read(8)
des, label_nums = struct.unpack_from('>ii', data, 0)
test_y = np.zeros((label_nums, 1))
for index in range(label_nums):
data = f.read(1)
test_y[index, :] = np.array(struct.unpack_from('>b', data, 0)).reshape(1, 1)
f.close()
if data_type == 'train':
return train_x, train_y
elif data_type == 'test':
return test_x, test_y
elif data_type == 'all':
return train_x, train_y,test_x, test_y
else:
print('type error')
if __name__ == '__main__':
tr_x, tr_y, te_x, te_y = fetch_mnist(mnist_dir,'all')
import matplotlib.pyplot as plt # plt 用於顯示
img_0 = tr_x[59999,:].reshape(28,28)
plt.imshow(img_0)
print(tr_y[59999,:])
img_1 = te_x[500,:].reshape(28,28)
plt.imshow(img_1)
print(te_y[500,:])
plt.show()
執行結果:
簡單 mnist 資料集轉為csv格式讀取
對於剛入門ai的童鞋來說,mnist 資料集就相當於剛接觸程式設計時的 hello world 一樣,具有別樣的意義,後續許多機器學習的演算法都可以用該資料集來進行簡單測試。也給出了資料集的格式,但是要手動解析這些資料也是有點複雜的。以下 的功能是將訓練集和訓練標籤整合到乙個csv檔案裡 測試檔案同...
MNIST資料集的處理
1 mnist資料集介紹 資料格式介紹 2 資料讀取 mnist資料集的讀取比較複雜,這裡給出兩種讀取方式。2.1 struct包讀取資料 nn網路中使用的讀取方法 2.2 torch.version和torch.utils.data.dataloader處理資料 import torch from...
MNIST資料集介紹
mnist資料集包含了6w張作為訓練資料,1w作為測試資料。在mnist資料集中,每一張都代表了0 9中的乙個數字,的大小都是28 28,且數字都會出現在的正中間。資料集包含了四個檔案 t10k images idx3 ubyte.gz 測試資料 t10k labels idx1 ubyte.gz ...