乙個簡單的分類問題案例的pytorch實現

2021-10-24 02:35:00 字數 3338 閱讀 8629

import torch

import torch.nn.functional as f

import matplotlib.pyplot as plt

n_data = torch.ones(

100,2)

# torch.normal(mean, std, *, generator=none, out=none) → tensor

# mean 是乙個張量,每個輸出元素的正態分佈均值

# std 是乙個張量,每個輸出元素的正態分佈的標準偏差

x0 = torch.normal(

2*n_data,1)

y0 = torch.zeros(

100)

# 第乙個分類標籤為0

x1 = torch.normal(-2

*n_data,1)

y1 = torch.ones(

100)

# 第二個分類標籤為1

x = torch.cat(

(x0, x1),0

).type

(torch.floattensor)

y = torch.cat(

(y0, y1),)

.type

(torch.longtensor)

# matplotlib.pyplot.scatter(x, y, s=none, c=none, marker=none, cmap=none,

# norm=none, vmin=none, vmax=none, alpha=none, linewidths=none,

# verts=none, edgecolors=none, *, data=none, **kwargs)

# x, y: 表示的是大小為(n,)的陣列,也就是我們繪製散點圖的資料點,輸入資料。

# s: 是乙個實數或者是乙個陣列大小為(n,),可選,預設為20。點的面積。

# c: 表示的是顏色,可選。預設是藍色'b',表示的是標記的顏色,

# 或者是乙個表示顏色的字元,或者是乙個長度為n的表示顏色的序列等等,

# 但是c不可以是乙個單獨的rgb數字,也不可以是乙個rgba的序列。可以是他們的二維陣列(只有一行)。

# marker: 表示的標記的樣式,可選,預設的是'o'

# cmap: colormap,標量或者是乙個colormap的名字,cmap僅僅當c是乙個浮點數陣列的時候才使用。

# 如果沒有申明就是image.cmap,可選,預設為none。

# norm: normalize,資料亮度在0-1之間,也是只有c是乙個浮點數陣列的時候才使用。

# 如果沒有申明就是color.normalize,就是預設none

# vmin,vmax: 標量,當norm存在的時候忽略。用來進行亮度資料的歸一化,可選,預設none。

# alpha: 標量,0-1之間,可選,預設none,線的透明度

# linewidths: 也就是標記點的長度,預設none,即點的直徑

# plt.scatter(x.data[:,0], x.data[:,1], c=y.data, s=100, lw=0, cmap='rdylgn')

# plt.show()

class

net(torch.nn.module)

:def

__init__

(self, n_data, n_hidden, n_output)

:super

(net, self)

.__init__(

) self.hidden = torch.nn.linear(n_data, n_hidden)

self.predict = torch.nn.linear(n_hidden, n_output)

defforward

(self, x)

: x = torch.relu(self.hidden(x)

)#使用relu作為啟用函式

x = self.predict(x)

return x

net = net(2,

10,2)

# 類的例項化

print

(net)

plt.ion(

)plt.show(

)optimizer = torch.optim.sgd(net.parameters(

), lr=

0.03

)# 使用隨機梯度下降作為優化器

loss_function = torch.nn.crossentropyloss(

)# 使用交叉熵誤差來作為損失函式

for t in

range

(100):

out = net(x)

# 正向傳播

loss = loss_function(out, y)

# 計算損失值

optimizer.zero_grad(

)# 優化器清零

loss.backward(

)# 反向傳播求梯度

optimizer.step(

)# 更新可訓練引數

if t %2==

0:plt.cla(

)# clear axis即清除當前圖形中的當前活動軸。其他軸不受影響

prediction = torch.

max(f.softmax(out, dim=0)

,1)[

1]# 輸出softmax計算之後的最大值的索引

pred_y = prediction.data.squeeze(

) target_y = y.data

plt.scatter(x.data[:,

0], x.data[:,

1], c=pred_y, s=

100, lw=

0, cmap=

'rdylgn'

)# 繪製散點圖

accuracy =

sum(pred_y==target_y)

/200

# 計算識別精度

plt.text(

1.5,-4

,'accuracy=%.2f'

% accuracy, fontdict=

)# 輸出乙個實時的精度計算文字

plt.pause(

0.05

)# 0.05秒的暫停時間

plt.ioff(

)plt.show(

)

這段**執行之後在plt.pause(0.05)處會報錯,原因還未搞懂,希望各位大佬不吝賜教,謝謝!

乙個簡單的動態規劃問題 小偷案例

前言 一 案例描述 二 問題分析 三 示例 總結動態規劃是一種演算法技巧,先舉乙個例子 如何讓乙個四歲的小孩理解動態規劃的思路?國外友人有這樣乙個例子 列出乙個1 1 1 1 1 1 1 1 的式子,讓小孩回答,小孩思索數秒後會告訴你答案是8。隨後在前面再多寫乙個 1,再提問答案是多少,小孩會瞬間告...

乙個簡單的sql審核案例

今天開發的同學發來一封郵件,希望我幫忙對乙個sql語句做乙個評估。他們也著急要用,但是為了穩妥起見,還是希望我來審核一下,這是乙個好的習慣。開啟郵件,看到的語句是下面這樣的形式。select a.cout1 b.cout2 from select count as cout1 from test o...

乙個案例的簡單總結

翻看去年處理的乙個案例,發現處理時間挺長的,而且這個案例也有點意思,就再看多兩眼,做個簡單總結。1.首先是應用伺服器效能不穩定,排查之後,伺服器是vm,要求加資源,並且所有資源都reserved.2.接著就是應用伺服器連線資料庫時很不穩定,資料庫經常報 recovery mode 好像是資料庫莫名被...