TensorFlow2 0之RNN情感分類問題實戰

2021-10-14 09:43:40 字數 4538 閱讀 2416

tensorflow2.0之rnn情感分類問題實戰

import tensorflow as tf

from tensorflow import keras

from tensorflow.keras import sequential, layers, datasets, optimizers, losses

import numpy as np

import os

tf.random.set_seed(22)

np.random.seed(22)

os.environ[

'tf_cpp_min_log_level']=

'2'assert tf.__version__.startswith(

'2.'

)# 載入資料

# 批大小

batchsize =

128# 詞彙表大小

total_words =

10000

# 句子最大長度s,大於的句子部分將截斷,小於的將填充

max_review_len =

80# 詞向量特徵長度n

embedding_len =

100# 載入imdb資料集,此處的資料採用數字編碼,乙個數字代表乙個單詞

(x_train, y_train)

,(x_test, y_test)

= keras.datasets.imdb.load_data(num_words=total_words)

# 列印輸入的形狀,標籤的形狀

# print(x_train.shape, len(x_train[0]), y_train.shape) # (25000,) 218 (25000,)

# print(x_test.shape, len(x_test[0]), y_test.shape) # (25000,) 68 (25000,)

# 數字編碼表

word_index = keras.datasets.imdb.get_word_index(

)# 列印出編碼表的單詞和對應的數字

# for k, v in word_index.items():

# print(k, v)

# 前面四個id是特殊位

word_index =

# 填充標誌

word_index[""]

=0# 起始標誌

word_index[""]

=1# 未知單詞的標誌

word_index[""]

=2# 沒有用到單詞的標誌

word_index[""]

=3# 翻轉編碼表

reverse_word_index =

dict([

(value, key)

for(key, value)

in word_index.items()]

)# 將數字編碼的句子轉換為字串資料

defdecode_review

(text)

:return

''.join(

[reverse_word_index.get(i,

'?')

for i in text]

)decode_review(x_train[0]

)# print(x)

# 截斷和填充句子,使得等長,此處長句子保留句子後面的部分,短句子在前面填充

x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_review_len)

x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_review_len)

# 構建資料集,打散,批量,並丟掉最後乙個不夠batchsize的batch

db_train = tf.data.dataset.from_tensor_slices(

(x_train, y_train)

)db_train = db_train.shuffle(

1000

).batch(batchsize, drop_remainder=

true

)db_test = tf.data.dataset.from_tensor_slices(

(x_test, y_test)

)db_test = db_test.batch(batchsize, drop_remainder=

true

)# 統計資料集屬性

# print('x_train shape:', x_train.shape, tf.reduce_max(y_train), tf.reduce_min(y_train))

# print('x_test shape:', x_test.shape)

# 網路模型

class

myrnn

(keras.model)

:# cell方式構建多層網路

def__init__

(self, units)

:super

(myrnn, self)

.__init__(

)# [b, 64], 構建cell初始化狀態向量,重複使用

self.state0 =

[tf.zeros(

[batchsize, units])]

self.state1 =

[tf.zeros(

[batchsize, units])]

# 詞向量編碼[b, 80] => [b, 80, 100]

self.embedding = layers.embedding(total_words, embedding_len, input_length=max_review_len)

# 構建兩個cell,使用dropout技術防止過擬合

self.rnn_cell0 = layers.******rnncell(units, dropout=

0.5)

self.rnn_cell1 = layers.******rnncell(units, dropout=

0.5)

# 構建分類網路,用於將cell的輸出特徵進行分類,2分類

# [b, 80, 100] => [b, 64] => [b, 1]

self.outlayer = sequential(

[ layers.dense(units)

, layers.dropout(rate=

0.5)

, layers.relu(),

layers.dense(1)

])defcall

(self, inputs, training=

none):

# [b, 80]

x = inputs

# 獲取詞向量:[b, 80] => [b, 80, 100]

x = self.embedding(x)

# 通過2個rnn cell,[b, 80, 100] => [b, 64]

state0 = self.state0

state1 = self.state1

# word:[b, 100]

for word in tf.unstack(x, axis=1)

: out0, state0 = self.rnn_cell0(word, state0, training)

out1, state1 = self.rnn_cell1(out0, state1, training)

# 末層最後乙個輸出作為分類網路的輸入:[b, 64] => [b, 1]

x = self.outlayer(out1, training)

# 通過啟用函式,p(y is pos|x)

prob = tf.sigmoid(x)

return prob

defmain()

:# rnn 狀態向量長度n

units =

64# 訓練世代

epochs =

50# 建立模型

model = myrnn(units)

# 裝配

model.

compile

(optimizer=optimizers.adam(1e-

3), loss=losses.binarycrossentropy(

), metrics=

['accuracy'],

experimental_run_tf_function=

false

)# 訓練和驗證

model.fit(db_train, epochs=epochs, validation_data=db_test)

# 測試

model.evaluate(db_test)

if __name__ ==

'__main__'

: main(

)

tensorflow2 0之one hot函式使用

先了解一下one hot要幹啥吧。來,咱先看個程式,你一定會很眼熟的。嘿,是不是發現什麼了?labels向量最後可以表示乘矩陣的方式,且 1 0 0 表示0,類推 0 1 0 0 表示1,這樣,可以表示 0 9總共九個數值。one hot的作用就是這樣的,作為儲存標籤的一種方式,用1的位置不同來區分...

Tensorflow2 0之卷積層實現

在 tensorflow 中,通過tf.nn.conv2d 函式可以方便地實現2d 卷積運算。tf.nn.conv2d基於輸入?和卷積核?進行卷積運算,得到輸出?其中?表示輸入通道數,表示卷積核的數量,也是輸出特徵圖的通道數。例如 in 1 x tf.random.normal 2 5,5 3 模擬...

tensorflow2 0視訊記憶體設定

遇到乙個問題 新買顯示卡視訊記憶體8g但是tensorflow執行的時候介面顯示只有約6.3g的視訊記憶體可用,如下圖 即限制了我的視訊記憶體,具體原因為什麼我也不知道,但原來的視訊記憶體小一些的顯示卡就沒有這個問題。目前的解決辦法是 官方文件解決 然後對應的中文部落格 總結一下,就是下面的兩個辦法...