實驗 鳶尾花分類 簡單的神經網路

2022-09-22 05:51:07 字數 1997 閱讀 8426

import

torch

from torch import

nnfrom sklearn.datasets import

load_iris

from sklearn.model_selection import

train_test_split

import

numpy as np

import

matplotlib.pyplot as plt

x = torch.tensor(load_iris().data, dtype=torch.float32)

y = torch.tensor(load_iris().target, dtype=torch.long)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

匯入鳶尾花資料集,這裡注意資料和標籤型別的設定:dtype=torch.float32,dtype=torch.long,否則會報錯

net = nn.sequential(nn.linear(4, 10), 

nn.relu(),

nn.linear(10, 10),

nn.relu(),

nn.linear(10, 3))

definit_weights(m):

if type(m) ==nn.linear:

nn.init.normal_(m.weights, std=0.01)

loss = nn.crossentropyloss(reduction="

none

")

trainer = torch.optim.adam(net.parameters(), lr=0.05)

train_loss =

test_loss =

train_l =sum(loss(net(x_train), y_train)).detach().numpy()

test_l =sum(loss(net(x_test), y_test)).detach().numpy()

epochs = 1000

for i in

range(epochs):

trainer.zero_grad()

l =sum(loss(net(x_train), y_train))

l.backward()

trainer.step()

l =sum(loss(net(x), y))

train_l =sum(loss(net(x_train), y_train)).detach().numpy()

test_l =sum(loss(net(x_test), y_test)).detach().numpy()

epoch_index = range(epochs + 1)

plt.plot(epoch_index, train_loss,

'green

', epoch_index, test_loss, '

blue

')

plt.show()

使用交叉熵損失函式時, 定義神經網路架構的時候不需要用softmax !     (我一開始在神經網路最後一層加了nn.softmax有報錯)關於交叉熵損失函式,nn.crossentropyloss(),有一些需要注意的點

貼篇網上介紹的部落格,後面看自己有沒有時間總結下。

有些場合(例如用matplotlib繪圖)需要用numpy的陣列,使用能求梯度的tensor是會報錯的!

這裡用.detach().numpy()來完成,例子可以見上面的**

實驗結果:

神經網路鳶尾花分類

import tensorflow as tf import numpy as np from sklearn import datasets import pandas as pd import matplotlib.pyplot as plt 從sklearn包datasets中讀入資料集 返回...

神經網路學習案例 鳶尾花分類問題

結合 機器學習實戰 和泰迪杯師資培訓,使用神經網路完成了鳶尾花問題,將 完善並記錄下來。書中對構建 ann 的兩種方法做了詳細解釋,還有在編譯模型時如何選擇損失函式和優化器,以及在不同層啟用函式的選擇。構建人工神經網路的兩種方法如下 第一種方法 model tf.keras.sequential t...

tensorflow 神經網路實現鳶尾花分類

主要步驟 1.準備資料 2.搭建網路 3.引數優化 4.測試效果 import tensorflow as tf from sklearn import datasets from matplotlib import pyplot as plt import numpy as np 匯入輸入特徵和標...