Autoencoding the MNIST dataset


"""

自動編碼的核心就是各種全連線的組合,它是一種無監督的形式,因為他的標籤是自己。

"""import

torch

import

torch.nn as nn

from torch.autograd import

variable

import

torch.utils.data as data

import

torchvision

import

matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import

axes3d

from matplotlib import

cmimport

numpy as np

# Hyperparameters
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005
DOWNLOAD_MNIST = False   # set to True on the first run to download the dataset
N_TEST_IMG = 5           # number of images shown in the reconstruction preview

# MNIST dataset
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),  # rescales loaded batches to [0, 1]
    download=DOWNLOAD_MNIST,
)
print(train_data.train_data.size())    # (60000, 28, 28)
print(train_data.train_labels.size())  # (60000)

# Show one example
plt.imshow(train_data.train_data[2].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[2])
plt.show()
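A side note (my addition, not in the original post): train_data.train_data is the raw uint8 image tensor with values in 0-255, while ToTensor() only rescales the batches served by the loader. That is why the preview code further down divides by 255 manually:

# The raw tensor is uint8 in [0, 255]; ToTensor() only affects what the
# DataLoader yields, not the train_data.train_data attribute itself.
print(train_data.train_data.dtype)  # torch.uint8
print(train_data.train_data.max())  # tensor(255, dtype=torch.uint8)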

# Split the dataset into mini-batches
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
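To make the batch layout concrete, here is a quick peek (my addition, not in the original post) at what the loader yields: a (BATCH_SIZE, 1, 28, 28) image tensor plus the matching labels.

# Peek at one mini-batch: images come out as (64, 1, 28, 28) floats in [0, 1]
images, labels = next(iter(train_loader))
print(images.shape, images.min().item(), images.max().item())
print(labels.shape)  # torch.Size([64])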

# Build the autoencoder network
class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),   # compress each image down to 3 features
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28*28),
            nn.Sigmoid(),       # squash the output into [0, 1], since train_data's values lie in [0, 1]
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

autoencoder = AutoEncoder()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()
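Before training, a quick shape check (my addition, not part of the original post; the dummy batch is made up) confirms the bottleneck: the encoder maps each flattened 28×28 image to 3 numbers, and the decoder maps those 3 numbers back to 784 pixels.

# Sanity check on a fake batch of 4 flattened images
dummy = torch.rand(4, 28*28)
enc, dec = autoencoder(dummy)
print(enc.shape)  # torch.Size([4, 3])
print(dec.shape)  # torch.Size([4, 784])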

# Initialize the figure
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()   # interactive mode, so the figure updates during training

# First row: the original images
view_data = Variable(train_data.train_data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255.)
for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap='gray')
    a[0][i].set_xticks(()); a[0][i].set_yticks(())

for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        b_x = Variable(x.view(-1, 28*28))   # input: the flattened image
        b_y = Variable(x.view(-1, 28*28))   # target: the same image, since the label is the input itself

        encoded, decoded = autoencoder(b_x)

        loss = loss_func(decoded, b_y)
        optimizer.zero_grad()   # clear the gradients from the previous step
        loss.backward()         # backpropagate to compute gradients
        optimizer.step()        # update the network parameters

        if step % 100 == 0:
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item())  # loss.item() replaces the old loss.data[0]

            # Second row: the decoded (reconstructed) images
            _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray')
                a[1][i].set_xticks(()); a[1][i].set_yticks(())
            plt.draw(); plt.pause(0.05)

plt.ioff()
plt.show()
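After training, a minimal check (my addition, reusing the view_data preview tensor and loss_func from above) reports the average reconstruction error on the five preview images:

# Average MSE between the 5 preview images and their reconstructions
with torch.no_grad():
    _, reconstructed = autoencoder(view_data)
    print('preview reconstruction MSE: %.4f' % loss_func(reconstructed, view_data).item())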

# Visualize the 3-D latent space
view_data = Variable(train_data.train_data[:200].view(-1, 28*28).type(torch.FloatTensor)/255.)
encoded_data, _ = autoencoder(view_data)
fig = plt.figure(2); ax = Axes3D(fig)
# Keep the coordinate arrays as X, Y, Z so the loop variables x, y, z below do not overwrite them
X = encoded_data.data[:, 0].numpy()
Y = encoded_data.data[:, 1].numpy()
Z = encoded_data.data[:, 2].numpy()
values = train_data.train_labels[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255*s/9))   # one color per digit class 0-9
    ax.text(x, y, z, s, backgroundcolor=c)
ax.set_xlim(X.min(), X.max()); ax.set_ylim(Y.min(), Y.max()); ax.set_zlim(Z.min(), Z.max())
plt.show()
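One compatibility caveat (my addition): on recent matplotlib releases, constructing Axes3D(fig) directly may no longer attach the axes to the figure. An equivalent, assuming the same encoded_data as above, is to request a 3-D subplot instead:

# Newer-matplotlib equivalent of "ax = Axes3D(fig)"
fig = plt.figure(2)
ax = fig.add_subplot(projection='3d')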
