TensorFlow中的GNMT模型構建大致過程

2021-10-03 03:01:40 字數 2558 閱讀 2425

tensorflow中gnmt的實現採取多層雙向lstm構建,構建基本過程如下:

encoder:

bi_output, bi_state = tf.nn.bidirectional_dynamic_rnn(

tf.nn.bidirectional_dynamic_rnn(fw_cell,bw_cell,inputs)

其中fw_cell、bw_cell和下方組裝的multirnncell相同,inputs=encoder_emb_inp(詞嵌入),而bi_output作為dynamic_rnn的引數,bi_state和encoder_state組成最終輸入的encoder_state

單層:

(1) tf.contrib.rnn.basiclstmcell(num_units,forget_bias=forget_bias)

(2) tf.contrib.rnn.grucell(num_units)

(3) tf.contrib.rnn.layernormbasiclstmcell

多層:

通過for i in range(num_layers)組建cell_list,並在layer數大於1時返回tf.contrib.rnn.multirnncell(cell_list)

建立好上述多層雙向結構後,推入rnn運算 encoder_outputs, encoder_state = tf.nn.dynamic_rnn(cell_list,bi_output..)。返回結果encoder_outputs和encoder_state作為decoder的輸入

decoder:

輸入:encoder_outputs, encoder_state,decoder_emb_inp(翻譯目標嵌入向量)

模型:

tf.contrib.seq2seq.traininghelper(decoder_emb_inp,length,...) //獲取原文對應的翻譯結果

my_decoder = tf.contrib.seq2seq.basicdecoder(cell_list, helper, decoder_initial_state)

outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(my_decoder,...)

從outputs中獲取sample_id和rnn_output,其中sample_id就是翻譯結果對應詞典中的id,rnn_output就是最後乙個隱層h,通過tf.dense(rnn_output)獲取logits

loss:

crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(

labels=target_output, logits=logits)

// 最終的loss還需要去掉毛重(填充部分mask掉),和batch_size平均一下

target_weights = tf.sequence_mask(

self.iterator.target_sequence_length, max_time, dtype=logits.dtype)

loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(self.batch_size)

訓練方向:

有了loss以後,定義梯度下降的目標是使loss變低,從而反向更新loss計算過程中用到的各層網路權值,定義過程:                  

opt = tf.train.gradientdescentoptimizer(learning_rate)    // 也有其它選項如adam等

gradients = tf.gradients(self.train_loss,...)

// 其中max_gradient_norm梯度裁剪用,預設是5.0,防止梯度**

clipped_gradients, gradient_norm = tf.clip_by_global_norm(gradients, max_gradient_norm)

zip(clipped_gradients, params), global_step=self.global_step)

和訓練過程的decoder比較相似,模型**於訓練記錄的結果,只有兩點不同

1、資料使用推斷資料

2、加入集束搜尋(beam_width)

tf.contrib.seq2seq.traininghelper + tf.contrib.seq2seq.basicdecoder 轉變為beamsearchdecoder 

--------->  my_decoder =  tf.contrib.seq2seq.beamsearchdecoder 

Tensorflow中的Lazy load問題

用tensorflow訓練或者inference模型的時候,有時候會遇到執行越來越慢,最終記憶體被佔滿,導致電腦宕機的問題,我們稱之為記憶體溢位。出現這種問題很可能是因為在乙個session中,graph迴圈建立重複的節點所導致的lazy load問題。舉個例子,用tensorflow迴圈做多次加法...

tensorflow中的函式

執行當前tensor的run 操作 a tf.variable tf.ones 10 with tf.session as sess tf.global variables initializer run b sess.run a 1 b a eval 2 print b 中的 1 2 行的功能一樣...

tensorflow中的global step引數

global step在滑動平均 優化器 指數衰減學習率等方面都有用到,這個變數的實際意義非常好理解 代表全域性步數,比如在多少步該進行什麼操作,現在神經網路訓練到多少輪等等,類似於乙個鐘錶。global step經常在滑動平均,學習速率變化的時候需要用到,這個引數在tf.train.gradien...