mnist手寫數字集製作tfrecords資料格式

2021-09-12 09:19:07 字數 2347 閱讀 6522

tfrecords是一種二進位制檔案,可先將與標籤製作成該格式的檔案,使用tfrecords進行資料讀取,會提高記憶體利用率,將不同輸入檔案統一起來。

檔案生成的過程:

具體**如下:

def generate_tfrecord():

i***ists=os.path.exists(data_path) ##判斷儲存路徑是否存在

if not i***ists:

os.makedirs(data_path)

print('路徑建立成功')

else:

print('路徑已存在')

write_tfrecord(tfrecord_train, image_train_path, label_train_path) ##使用自定義函式將訓練集生成名叫tfrecord_train的tfrecords檔案

write_tfrecord(tfrecord_test, image_test_path, label_test_path) ##同理訓練集

def write_tfrecord(tfrecordname, image_path, label_path):

writer=tf.python_io.tfrecordwriter(tfrecordname) ##建立乙個writer

num_pic=0 ##計數器

f=open(label_path,'r') ##以讀的形式開啟標籤檔案

contents = f.readlines() ##讀取整個檔案內容

f.close()

for content in contents:

value=content.split() ##以空格分隔每行的內容,分割後組成列表value

img_path=image_path+value[0]

img=image.open(img_path) ##開啟

img_raw=img.tobytes() ##將轉換為二進位制資料

labels=[0]*10

labels[int(value[1])]=1 ##將labels所對應的標籤為賦值為1

example=tf.train.example(features=tf.train.features(feature=)) ##在labels放入對應的標籤

writer.write(example.serializetostring()) ##將example進行序列化

num_pic+=1

writer.close()

print('tfrecord檔案寫入成功')

檔案讀取的過程:

具體**如下:

###實現了批獲取訓練集或測試集的和標籤

def get_tfrecord(num, istrain=true): ##引數num表示一次讀取多少組

if istrain:

tfrecord_path=tfrecord_train

else:

tfrecord_path=tfrecord_test

img,label=read_tfrecord(tfrecord_path)

img_batch, label_batch= tf.train.shuffle_batch([img, label],

batch_size=num,

capacity=1000,

min_after_dequeue=700,

num_threads=2)

def read_tfrecord(tfrecord_path):

filmname_queue=tf.train.string_input_producer([tfrecord_path])

reader=tf.tfrecordreader() ##新建乙個reader

_,serialized_example = reader.read(filmname_queue) ##將讀出的每乙個樣本儲存到serialized_example中進行解序列化

features=tf.parse_single_example(serialized_example,features=)

img=tf.decode_raw(features['img_raw'],tf.uint8) ##將img_raw字串轉化為8位無符號整型

img.set_shape([784])

img=tf.cast(img,tf.float32)*(1./255) ##轉化為浮點數形式

label=tf.cast(features['label'],tf.float32)

return img, label

keras 實現mnist手寫數字集識別

coding utf 8 classifier mnist import numpy as np np.random.seed 1337 from keras.datasets import mnist from keras.utils import np utils from keras.mode...

mnist手寫數字識別

import tensorflow as tf import numpy as np from tensorflow.contrib.learn.python.learn.datasets.mnist import read data sets mnist read data sets f pyth...

MNIST手寫數字識別 tensorflow

神經網路一半包含三層,輸入層 隱含層 輸出層。如下圖所示 現以手寫數字識別為例 輸入為784個變數,輸出為10個節點,10個節點再通過softmax啟用函式轉化為 值。如下,準確率可達0.9226 import tensorflow as tf from tensorflow.examples.tu...