1.首先想到的是用tf.placeholder()載入不同的資料來進行計算,比如
def inference(input_):
"""this is where you put your graph.
the following is just an example.
"""conv1 = tf.layers.conv2d(input_)
conv2 = tf.layers.conv2d(conv1)
return conv2
input_ = tf.placeholder()
output = inference(input_)
...calculate_loss_op = ...
train_op = ...
...with tf.session() as sess:
sess.run([loss, train_op], feed_dict=)
if validation == true:
sess.run([loss], feed_dict=)
這種方式很簡單,也很直接了然。
2.但是,如果處理的資料量很大的時候,使用 tf.placeholder() 來載入資料會嚴重地拖慢訓練的進度,因此,常用tfrecords檔案來讀取資料。
此時,很容易想到,將不同的值傳入inference()函式中進行計算。
train_batch, label_batch = decode_train()
val_train_batch, val_label_batch = decode_validation()
train_result = inference(train_batch)
...loss = ..
train_op = ...
...if validation == true:
val_result = inference(val_train_batch)
val_loss = ..
with tf.session() as sess:
sess.run([loss, train_op])
if validation == true:
sess.run([val_result, val_loss])
這種方式看似能夠直接呼叫inference()來對驗證資料進行前向傳播計算,但是,實則會在原圖上新增上許多新的結點,這些結點的引數都是需要重新初始化的,也是就是說,驗證的時候並不是使用訓練的權重。
3.用乙個tf.placeholder來控制是否訓練、驗證。
def inference(input_):
......
...return inference_result
train_batch, label_batch = decode_train()
val_batch, val_label = decode_validation()
is_training = tf.placeholder(tf.bool, shape=())
x = tf.cond(is_training, lambda: train_batch, lambda: val_batch)
y = tf.cond(is_training, lambda: train_label, lambda: val_label)
logits = inference(x)
loss = cal_loss(logits, y)
train_op = optimize(loss)
with tf.session() as sess:
loss, _ = sess.run([loss, train_op], feed_dict=)
if validation == true:
loss = sess.run(loss, feed_dict=)
使用這種方式就可以在乙個大圖里建立乙個分支條件,從而通過控制placeholder來控制是否進行驗證。 一邊學習,一邊記錄
1 flutter channels 檢視flutter分支 帶星 表示當前的分支 2 flutter doctor 檢視flutter環境配置完成情況 3 flutter devices 執行時需要的裝置 4 flutter upgrade 更新flutter 5 flutter packages...
如何一邊寬度自適應 一邊寬度固定
一 右側固定寬度 左側自適應 第一種 左側用margin right,右側float right 就可以實現。html 可以如下寫 我是龍恩 我是龍恩 css 可以如下寫 box left box right 如上 就可以實現效果。第二種 左側同樣用margin right 右側採用絕對定位 如下 ...
解決List如何一邊遍歷一邊刪除操作
在老專案中很少技術負責人會你遍歷list的同時再進行刪除操作,因為很容易報錯,或者出現疏漏資料。不規範的 public static void main string args system.out.println platformlist 原因 list中進行操作時會改變modcount的值,而遍...