[数据集]cifar-10

cifar-10数据集保存10类,每类6000张图像。其中50000张训练图像和10000张测试图像

训练图像保存在5个文件中,每个文件有10000张图像,测试图像保存在一个文件,训练和测试图像都以随机顺序保存

官网:The CIFAR-10 dataset

cifar-10提供了使用不同语言生成的压缩包,包括python/matlab/c

python解析

训练文件命名为data_batch_1/data_batch_2/data_batch_3/data_batch_4/data_batch_5

测试文件命名为test_batch

另外还有一个元数据文件batches.meta,保存了标签名对应的类名,比如label_names[0] == "airplane", label_names[1] == "automobile"

python版压缩包使用pickle模块进行保存,解析程序如下,返回一个dict

1
2
3
4
5
def unpickle(file):
import pickle
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict

以测试集文件为例,其包含以下键值对

1
dict_keys([b'batch_label', b'labels', b'data', b'filenames'])

其中b'labels'保存类别标签(0-9),b'data'保存对应图像数据,b'filenames'保存图像文件名

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
if __name__ == '__main__':
data_dir = '/home/zj/data/cifar-10-batches-py/'
test_data_dir = data_dir + 'test_batch'
di = unpickle(test_data_dir)
# 打印所有键值
print(di.keys())

batch_label = di.get(b'batch_label')
filenames = di.get(b'filenames')
labels = di.get(b'labels')
data = di.get(b'data')

print(batch_label)
print(filenames[0])
print(labels[0])
print(data[0])

dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
b'testing batch 1 of 1'
b'domestic_cat_s_000907.png'
3
[158 159 165 ... 124 129 110]

数据格式

b'data'保存了图像数据,其值类型是numpy.ndarray,每一行都保存了32x32大小图像数据

其中前1024个字节是红色分量值,中间1024个字节是绿色分量值,最后1024个字节是蓝色分量值

解析程序如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# -*- coding: utf-8 -*-

import numpy as np
import cv2 as cv


def unpickle(file):
import pickle
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict


def convert_data_to_img(data):
"""
转换数据为图像
:param data: 3072维
:return: img
"""
num = 1024
red = data[:num]
green = data[num:(num * 2)]
blue = data[(num * 2):]

img = np.ndarray((32, 32, 3))
for i in range(32):
for j in range(32):
img[i, j, 0] = blue[i * 32 + j]
img[i, j, 1] = green[i * 32 + j]
img[i, j, 2] = red[i * 32 + j]

return img


if __name__ == '__main__':
data_dir = '/home/zj/data/cifar-10-batches-py/'
test_data_dir = data_dir + 'test_batch'
di = unpickle(test_data_dir)

filenames = di.get(b'filenames')
data = di.get(b'data')

filename = str(filenames[0], encoding='utf-8')
img = convert_data_to_img(data[0])

cv.imwrite(filename, img)

解压代码

完整解压代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# -*- coding: utf-8 -*-

# @Time : 19-6-5 下午8:20
# @Author : zj

import numpy as np
import pickle
import os
import cv2

data_list = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5', 'test_batch']

res_data_dir = '/home/zj/data/decompress_cifar_10'


def unpickle(file):
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict


def write_img(data, labels, filenames, isTrain=True):
if isTrain:
data_dir = os.path.join(res_data_dir, 'train')
else:
data_dir = os.path.join(res_data_dir, 'test')

if not os.path.exists(data_dir):
os.mkdir(data_dir)

N = len(labels)
for i in range(N):
cate_dir = os.path.join(data_dir, str(labels[i]))
if not os.path.exists(cate_dir):
os.mkdir(cate_dir)
img_path = os.path.join(cate_dir, str(filenames[i], encoding='utf-8'))

# r = data[i][:1024].reshape(32, 32)
# g = data[i][1024:2048].reshape(32, 32)
# b = data[i][2048:].reshape(32, 32)
# img = cv2.merge((b, g, r))

img = data[i].reshape(3, 32, 32)
img = np.transpose(img, (1, 2, 0))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
cv2.imwrite(img_path, img)


if __name__ == '__main__':

for item in data_list:
data_dir = '/home/zj/data/cifar-10-batches-py/'
data_dir = os.path.join(data_dir, item)
di = unpickle(data_dir)

batch_label = str(di.get(b'batch_label'), encoding='utf-8')
filenames = di.get(b'filenames')
labels = di.get(b'labels')
data = di.get(b'data')

if 'train' in batch_label:
write_img(data, labels, filenames)
else:
write_img(data, labels, filenames, isTrain=False)

读取图像

读取解压后的图像并显示,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
if __name__ == '__main__':
data_dir = '/home/zj/data/cifar-10-batches-py/test_batch'

di = unpickle(data_dir)

batch_label = str(di.get(b'batch_label'), encoding='utf-8')
filenames = di.get(b'filenames')
labels = di.get(b'labels')
data = di.get(b'data')

N = 10
W = 32
H = 32
ex = np.zeros((H * N, W * N, 3))
for i in range(N):
for j in range(N):
img = data[i * N + j].reshape(3, H, W)
img = np.transpose(img, (1, 2, 0))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

ex[i * H:(i + 1) * H, j * W:(j + 1) * W] = img
plt.imshow(ex.astype(np.uint8), cmap='gray')
plt.axis('off')
plt.show()

坚持原创技术分享,您的支持将鼓励我继续创作!