tf.data是tensorflow2.0中加入的資料載入模組,是乙個非常便捷的處理資料的模組。
這裡簡單介紹一些tf.data的使用方法。
1.載入tensorflow中自帶的mnist資料
並對資料進行一些簡單的處理
1 (train_image, train_label), (test_image, test_label) =tf.keras.datasets.mnist.load_data()2 train_image = train_image / 255
3 test_image = test_image / 255
2.使用tf.data.dataset.from_tensor_slices()方法對資料進行切片處理
該函式是dataset核心函式之一,它的作用是把給定的元組、列表和張量等資料進行特徵切片。切片的範圍是從最外層維度開始的。如果有多個特徵進行組合,那麼一次切片是把每個組合的最外維度的資料切開,分成一組一組的。
1 ds_train_label =tf.data.dataset.from_tensor_slices(train_label)2 ds_train_label = tf.data.dataset.from_tensor_slices(train_label)
3.使用tf.data.dataset.zip()方法將image和label資料合併
tf.data.dataset.zip()方法可將迭代物件中相對應(例如image對應label)的資料打包成乙個元組,返回由這些元組組成的物件。
1 ds_train = tf.data.dataset.zip((ds_train_image, ds_train_label))
這裡ds_train中的資料就是由許多個(image, label)元組組成的。
事實上我們也可以直接把train_image與train_label進行合併,以元組的形式對train_image和train_label進行切片即可。
1 ds_trian = tf.data.dataset.from_tensor_slices((train_image, train_label))
4.使用.shuffle().repeat().batch()方法對資料進行處理
1 ds_train = ds_train.shuffle(10000).repeat(count = 3).batch(64)
.shuffle()作用是將資料進行打亂操作,傳入引數為buffer_size,改引數為設定「打亂快取區大小」,也就是說程式會維持乙個buffer_size大小的快取,每次都會隨機在這個快取區抽取一定數量的資料。
.repeat()作用就是將資料重複使用多少次,引數是重複的次數,若無引數則無限重複。
.batch()作用是將資料打包成batch_size, 每batch_size個資料打包在一起作為乙個epoch。
5.注意事項
在使用tf.data時,如果不設定資料的.repeat()的重複次數,資料會無限制重複,如果把這樣的資料直接輸入到神經網路中會導致記憶體不足程式無法終止等錯誤。此時,要在.fit()方法中加以限制。
1 history = model.fit(ds_train, epochs = 5, steps_per_epoch =step_per_epochs,2 validation_data = ds_test, validation_steps = 10000 // 64
3 )
使用steps_per_epoch引數限制每個epochs的資料量。
使用validation_steps限制驗證集中的資料量。
tensorflow學習筆記
tensorflow安裝可以直接通過命令列或者原始碼安裝,在此介紹tensorflow8命令列安裝如下 安裝tensorflow sudo pip install upgrade 另外,解除安裝tensorflow命令為 sudo pip uninstall tensorflow tensorflo...
Tensorflow學習筆記
1.如何在虛擬機器中安裝tensor flow 1 首先安裝pip pip install 2 pip install 2.學習tensorflow需要學習 python and linux 3.使用 tensorflow,你必須明白 tensorflow 1 使用圖 graph 來表示計算任務.2...
TensorFlow學習筆記
1 擬合直線 import the library import tensorflow as tf import numpy as np prepare train data train x np.linspace 1,1,100 temp1 train x,temp2 train x.shape,...