tensorflow2的資料載入

2021-10-23 01:43:19 字數 2821 閱讀 3821

對於一些小型常用的資料集,tensorflow有相關的api可以呼叫——keras.datasets

經典資料集:

1、boston housing- 波士頓房價

2、mnist/fasion mnist- 手寫數字集/時髦品集

3、cifar10/100- 物象分類

4、imdb- 電影評價

使用 tf.data.dataset 的好處:

1.既能讓後面有迭代的方式,又能直接對資料(tensor型別)進行預處理,還能支援batch和多執行緒的方式處理

2.提供了 .shuffle(打散), .map(預處理) 功能

(x, y)

,(x_val, y_val)

= datasets.fashion_mnist.load_data(

)

用於打亂資料集但不影響對映關係

from_tensor_slices的方法見這篇部落格

db = tf.data.dataset.from_tensor_slices(

(x_test, y_test)

)db = db.shuffle(

10000

)# x_test,y_test對映關係不變

用於使用預處理對映

tf.cast的用法見這篇部落格

tf.one_hot的用法見這篇部落格

def

preprocess

(x,y)

:# 定義乙個與處理函式 用於將numpy資料型別轉化為tensor的型別(dtype=float32)

x = tf.cast(x, dtype=tf.float32)

/255

# 將灰度級歸一化

y = tf.cast(y, dtype=tf.float32)

y = tf.one_hot(y, depth=10)

# 對數字編碼 y 進行one_hot編碼,10個0-1序列中只有乙個1

return x, y

db2 = db.

map(preprocess)

res =

next

(iter

(db2)

)# iter(db2):取得db2的迭代器,next(iter(db2)):迭代

批處理

db3 = db2.batch(32)

# (32張,32個label)為乙個batch

res =

next

(iter

(db3)

)# 進行迭代

res[0]

.shape, res[1]

.shape # 分別是乙個batch中格式與label格式的shape

(tensorshape([32

,32,32

,3])

, tensorshape([32

,1,10

]))# 格式是(32張,32*32大小,3個通道) # (32張對應的label,1個label——通常會squeeze掉,10個one_hot深度)

整個資料集的迴圈次數

db4 = db3.repeat(

)# 這樣就是一直repeat迭代,死迴圈

db4 = db3.repeat(2)

# 這個是迭代2次

def

prepare_mnist_features_and_labels

(x,y)

: x = tf.cast(x, tf.float32)

/255.0

y = tf.cast(y, tf.float64)

return x,y

defmnist_dataset()

:(x, y)

,(x_val, y_val)

= datasets.fashion_mnist.load_data(

)# 1.載入影象資料和通用資料(val指的是validation,測試資料集)

y = tf.one_hot(y, depth=10)

# 2.資料 one_hot編碼

y_val = tf.one_hot(y_val, depth=10)

# label one_hot編碼

ds = tf.data.dataset.from_tensor_slices(

(x, y)

)# 3.轉換為dataset型別

ds = ds.

map(prepare_mnist_features_and_labels)

# 4.預處理函式對映

ds = ds.shuffle(

60000

).batch(

100)

# 5.其他處理——如本處的前60000個打亂,100個為乙個批次

ds_val = tf.data.dataset.from_tensor_slices(

(x_val, y_val)

) ds_val = ds_val.

map(prepare_mnist_features_and_labels)

ds_val = ds_val.shuffle(

10000

).batch(

100)

return ds, ds_val

TensorFlow2學習八之資料增強

影象增強 對影象的簡單形變。tensorflow2影象增強函式tf.keras.preprocessing.image.imagedatagenerator image gen train tf.keras.preprocessing.image.imagedatagenerator rescale...

Tensorflow2 自動求導機制

tensorflow 引入了 tf.gradienttape 這個 求導記錄器 來實現自動求導 如何使用 tf.gradienttape 計算函式 y x x 2 在 x 3 時的導數 import tensorflow as tf x tf.variable initial value 3.初始化...

tensorflow2建立卷積核Conv2D函式

使用conv2d可以建立乙個卷積核來對輸入資料進行卷積計算,然後輸出結果,其建立的卷積核可以處理二維資料。依次類推,conv1d可以用於處理一維資料,conv3d可以用於處理三維資料。在進行神經層級整合時,如果使用該層作為第一層級,則需要配置input shape引數。在使用conv2d時,需要配置...