import os

import numpy as np
import tensorflow as tf
# You need to change this to your data directory.
# train_dir = '/home/kevin/tensorflow/cats_vs_dogs/data/train/'
# Directory that holds the training images.
train_dir = '/users/arcstone_mems_108/pycharmprojects/catsvsdogs/data/train/'
def get_files(file_dir):
    """Collect the cat/dog image paths under ``file_dir`` and shuffle them.

    Each file is classified by the first dot-separated part of its name
    (e.g. ``'dog.9981.jpg'`` -> ``['dog', '9981', 'jpg']`` -> ``'dog'``):
    ``cat`` files get label 0 and ``dog`` files get label 1. Any other file
    (``.DS_Store``, notes, ...) is skipped rather than being labelled a dog.

    Args:
        file_dir: file directory containing the images; it may or may not
            end with a path separator.

    Returns:
        A tuple ``(image_list, label_list)``: a list of image paths and a list
        of int labels, shuffled together (same permutation) with NumPy's
        global random state so each path stays paired with its label.
    """
    cats = []
    label_cats = []
    dogs = []
    label_dogs = []
    for fname in os.listdir(file_dir):
        name = fname.split(sep='.')
        # os.path.join is correct whether or not file_dir has a trailing '/'.
        path = os.path.join(file_dir, fname)
        if name[0] == 'cat':
            cats.append(path)
            label_cats.append(0)
        elif name[0] == 'dog':
            dogs.append(path)
            label_dogs.append(1)
    print('there are %d cats\nthere are %d dogs' % (len(cats), len(dogs)))

    # Merge the cat and dog lists into one list (cats first, then dogs).
    image_list = cats + dogs
    label_list = label_cats + label_dogs

    # Shuffle with one shared permutation so the order is no longer
    # "all cats, then all dogs" while every path keeps its own label.
    order = np.random.permutation(len(image_list))
    image_list = [image_list[i] for i in order]
    label_list = [int(label_list[i]) for i in order]
    return image_list, label_list
def get_batch(image, label, image_w, image_h, batch_size, capacity,
              num_threads=64, standardize=True):
    """Build a TF1 queue-based pipeline that yields batches of images and labels.

    Args:
        image: list type, image file paths (e.g. from ``get_files``).
        label: list type, integer labels aligned with ``image``.
        image_w: image width in pixels after crop/pad.
        image_h: image height in pixels after crop/pad.
        batch_size: batch size.
        capacity: the maximum elements in queue.
        num_threads: number of threads used by ``tf.train.batch`` to enqueue
            examples (default 64).
        standardize: if True (default; keep it True for training), apply
            ``tf.image.per_image_standardization`` and cast the batch to
            float32. Set it to False to inspect the decoded images as-is.
            In that case neither step is applied, so the batch keeps the
            dtype produced by ``decode_jpeg``.

    Returns:
        image_batch: 4D tensor [batch_size, image_h, image_w, 3],
            dtype=tf.float32 when ``standardize`` is True.
        label_batch: 1D tensor [batch_size], dtype=tf.int32.

    NOTE: this uses TF1 queue ops (``tf.train.slice_input_producer`` /
    ``tf.train.batch``); the caller presumably has to start the queue runners
    in its session -- confirm against the TF1 docs / training script.
    """
    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)

    # Make an input queue that produces one (path, label) slice at a time.
    input_queue = tf.train.slice_input_producer([image, label])
    label = input_queue[1]
    # Read the file bytes at the path in slot 0 and decode to a 3-channel tensor.
    image_contents = tf.read_file(input_queue[0])
    image = tf.image.decode_jpeg(image_contents, channels=3)

    # Data augmentation should go here.
    # resize_image_with_crop_or_pad takes (target_height, target_width):
    # height must be passed first or non-square outputs are transposed.
    image = tf.image.resize_image_with_crop_or_pad(image, image_h, image_w)

    if standardize:
        image = tf.image.per_image_standardization(image)

    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=num_threads,
                                              capacity=capacity)
    label_batch = tf.reshape(label_batch, [batch_size])
    if standardize:
        image_batch = tf.cast(image_batch, tf.float32)
    return image_batch, label_batch
