使用tensorflow 構造乙個神經元,簡單的線性回歸網路。
問題:現有一組有雜訊的樣本資料,共2000個,每乙個樣本 x 有 3 個特徵, 對應乙個標籤 y 值。從資料樣本中學習 y=w
×x+b
y=w\times x + b
y=w×x+
b 中的引數
首先我們來生成樣本資料,w_real 和 b_real 是控制樣本資料的引數的真實值,
x_data = np.random.randn(
2000,4
)w_real =
[0.2
,0.3
,0.1
,0.3
]b_real =
-0.3
noise = np.random.randn(1,
2000)*
0.1y_data = np.matmul(w_real, x_data.t)
+ b_real + noise
編寫神經網路
下面會用到的 tensorflow api
官方tensorflow 文件
全部源**實現:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# 建立資料模擬
x_data = np.random.randn(
2000,4
)w_real =
[0.2
,0.3
,0.1
,0.3
]b_real =
-0.3
noise = np.random.randn(1,
2000)*
0.1y_data = np.matmul(w_real, x_data.t)
+ b_real + noise
# 清除預設圖中的內容
tf.reset_default_graph(
)# 設定步數
num_step =
10# 學習率
learning_rate =
0.5# 建立圖
g = tf.graph(
)# 儲存wb
wb_sess =
with g.as_default():
# x, y_true 佔位符
x = tf.placeholder(tf.float32, name =
'x')
y_true = tf.placeholder(tf.float32, name =
'y_true'
)# w, b 變數
w = tf.variable([[
0,0,
0,0]
], dtype = tf.float32, name =
'w')
b = tf.variable(
0, dtype = tf.float32, name =
'b')
# **值 y = w * x + b
y_pred = tf.add(tf.matmul(w, tf.transpose(x)
), b, name =
'y_pred'
)# 損失 計算成員平均值
loss = tf.reduce_mean(tf.square(y_true - y_pred)
, name =
'loss'
)# 優化器,sgd
optimizer = tf.train.gradientdescentoptimizer(learning_rate, name=
'sgd'
) train = optimizer.minimize(loss, name =
'train'
)# 全域性初始化節點
init = tf.global_variables_initializer(
)with tf.session(
)as sess:
sess.run(init)
for step in
range
(num_step)
: sess.run(train,
)
[w, b]))
if(step %5==
0):print
(step +
1, sess.run(
[w, b]))
print
(num_step, sess.run(
[w, b]
))
總結:
這只會讓你了解tensorflow的一些api 特性,加強使用這些api,簡單模型。
用TensorFlow實現iris資料集線性回歸
本文將遍歷批量資料點並讓tensorflow更新斜率和y截距。這次將使用scikit learn的內建iris資料集。特別地,我們將用資料點 x值代表花瓣寬度,y值代表花瓣長度 找到最優直線。選擇這兩種特徵是因為它們具有線性關係,在後續結果中將會看到。本文將使用l2正則損失函式。用tensorflo...
用Tensorflow完成簡單的線性回歸模型
思路 在資料上選擇一條直線y wx b,在這條直線上附件隨機生成一些資料點如下圖,讓tensorflow建立回歸模型,去學習什麼樣的w和b能更好去擬合這些資料點。1 隨機生成1000個資料點,圍繞在y 0.1x 0.3 周圍,設定w 0.1,b 0.3,屆時看構建的模型是否能學習到w和b的值。imp...
tensorflow實現簡單的卷積網路
import tensorflow as tf import gc from tensorflow.examples.tutorials.mnist import input data mnist input data.read data sets f zxy python mnist data o...