官網預設定義如下:
one_hot(indices, depth, on_value=none, off_value=none, axis=none, dtype=none, name=none)
該函式的功能主要是轉換成one_hot型別的張量輸出。引數功能如下:
1)indices中的元素指示on_value的位置,不指示的地方都為off_value。indices可以是向量、矩陣。
2)depth表示輸出張量的尺寸,indices中元素預設不超過(depth-1),如果超過,輸出為[0,0,···,0]
3)on_value預設為1
4)off_value預設為0
5)dtype預設為tf.float32
下面用幾個例子說明一下:
1. indices是向量
1import
tensorflow as tf
23 indices = [0,2,3,5]
4 depth1 = 6 #
indices沒有元素超過(depth-1)
5 depth2 = 4 #
indices有元素超過(depth-1)
6 a =tf.one_hot(indices,depth1)
7 b =tf.one_hot(indices,depth2)89
with tf.session() as sess:
10print('
a = \n
',sess.run(a))
11print('
b = \n
',sess.run(b))
執行結果:
# 輸入是一維的,則輸出是乙個二維的a =[[1. 0. 0. 0. 0. 0.]
[0. 0. 1. 0. 0. 0.]
[0. 0. 0. 1. 0. 0.]
[0. 0. 0. 0. 0. 1.]] # shape=(4,6)
b =[[1. 0. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 0. 0. 0.]] # shape=(4,4)
2. indices是矩陣
1import
tensorflow as tf
23 indices = [[2,3],[1,4]]
4 depth1 = 9 #
indices沒有元素超過(depth-1)
5 depth2 = 4 #
indices有元素超過(depth-1)
6 a =tf.one_hot(indices,depth1)
7 b =tf.one_hot(indices,depth2)89
with tf.session() as sess:
10print('
a = \n
',sess.run(a))
11print('
b = \n
',sess.run(b))
執行結果:
# 輸入是二維的,則輸出是三維的a =[[[0. 0. 1. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 0. 0.]]
[[0. 1. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0.]]] # shape=(2,2,9)
b =[[[0. 0. 1. 0.]
[0. 0. 0. 1.]]
[[0. 1. 0. 0.]
[0. 0. 0. 0.]]] # shape=(2,2,4)
tensorflow中tfrecords使用介紹
這篇文章主要講一下如何用tensorflow中的標準資料讀取方式簡單的實現對自己資料的讀取操作 主要分為以下兩個步驟 1 將自己的資料集轉化為 xx.tfrecords的形式 2 在自己的程式中讀取並使用.tfrecords進行操作 資料集轉換 為了便於講解,我們簡單製作了乙個資料,如下圖所示 程式...
Tensorflow中dynamic rnn的用法
1 api介面dynamic rnn cell,inputs,sequence length none,initial state none,dtype none,parallel iterations none,swap memory false,time major false,scope no...
TensorFlow中遮蔽warning的方法
tensorflow的日誌級別分為以下三種 tf cpp min log level 1 預設設定,為顯示所有資訊 tf cpp min log level 2 只顯示error和warining資訊 tf cpp min log level 3 只顯示error資訊 所以,當tensorflow出...