我们使用 caffe,将图像数据转换为 caffe 可以识别的数据格式是第一步。同时大多数都是通过 python 接口来转换数据格式的。
LMDB 数据库 Caffe 使用 LMDB 的情况大约有两类:
第一类是 DataLayer 层中 使用的 训练集、验证集、测试集;
第二类 就是 ./caffe/build/tools/extract_feature.bin 这种特征提取工具提取特征后,输出的特征文件。
LMDB 的全称是 Lighting Memory-Mapped Database(闪电般的内存映射数据库) 。它文件结构简单,一个文件夹,里面一个数据文件,一个锁文件。数据随意复制,随意传输。它的访问简单,不需要运行单独的数据管理进程。只要在访问的代码里引用 LMDB 库,访问时给文件路径即可。
Caffe 中使用的数据较为很简单,就是大量的矩阵/向量平铺开来。数据之间没有什么关联,数据内没有复杂的对象结构,就是向量和矩阵。既然数据并不复杂,Caffe 就选择了 LMDB 这个简单的数据库来存放数据。
上面提到了,Caffe 使用 LMDB 数据库有两点原因:
一方面是因为数据源的格式多样性,有文本文件、二进制文件图像文件等等,不可能用一个代码完成上述所有的数据格式。因此,通过 LMDB 数据库,转化成统一的数据格式可以简化数据读取层的实现。
第二个方面就是使用 LMDB 数据库可以大大的节约磁盘 IO 的时间开销。因为读取大量小文件的时间开销是相当大的,尤其是在机械硬盘上。 数据库单文件还能减少数据集复制、传输过程的开销。因为我们都有过体会,一个具有几万个、几十万个文件的数据集,不管是直接复制,还是打开再解包,过程都巨慢无比。LMDB 只有一个文件,你的介质有多快,就能复制多快,不会因为文件多而慢的令人心碎。
Caffe 中 Datum 数据结构 Caffe 并不是把向量和矩阵直接放进数据库的,而是将数据通过 caffe.proto 里定义的一个 datum 类来封装的。数据库里存放的是一个个 datum 序列化成的字符串。Datum 的定义如下:
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
// the actual image data, in bytes
optional bytes data = 4;
optional int32 label = 5;
// Optionally, the datum could also hold float data.
repeated float float_data = 6;
// If true data contains an encoded image that need to be decoded
optional bool encoded = 7 [default = false];
}
一个 Datum 有三个维度,channnels、height、width,可以看作是少了 num 维度的 Blob。 存放数据的地方有两个:bytes data、float_data,分别存放整数型和浮点型数据。图像数据一般是整形,放在 bytes data 中,特征向量一般是浮点型,存放在 float_data 中。 label 里存放的是类别标签,是整数型。 encoded 标识数据是否需要被解码,因为里面可能存放的是 JPEG 或者 PNG 之类经过编码的数据。
Datum 这个数据结构将数据和标签封装在一起,兼容整形和浮点型数据。经过 protobuf 编译后,可以在 Python 和 C++ 中都提供高效的访问。 同时 protobuf 还为它提供了序列化、反序列化的功能。存放进 LMDB 的就是 Datum 序列化生成的字符串。
Caffe 中将图像写入 LMDB 数据库
我上面解析的 create_mnist_data.cpp 代码对于这部分是很有用的,特别是 LMDB 流程图中的 lmdb 数据操作函数,如打开一个 lmdb 数据库,写入数据等操作,python 中的使用类似,但比 C++ 的要简洁许多 。
下面通过代码来说明吧,这段代码是一个大牛写的教程:《A Practical Introduction to Deep Learning with Caffe and Python》,写的很清晰。
import os
import glob
import random
import numpy as np
import cv2
import caffe
from caffe.proto import caffe_pb2
import lmdb
#Size of images
IMAGE_WIDTH = 227
IMAGE_HEIGHT = 227
# train_lmdb、validation_lmdb 路径
train_lmdb = '/home/chenxp/Documents/vehicleID/val/train_lmdb'
validation_lmdb = '/home/chenxp/Documents/vehicleID/val/validation_lmdb'
# 如果存在了这个文件夹, 先删除
os.system('rm -rf ' + train_lmdb)
os.system('rm -rf ' + validation_lmdb)
# 读取图像
train_data = [img for img in glob.glob("/home/chenxp/Documents/vehicleID/val/query/*jpg")]
test_data = [img for img in glob.glob("/home/chenxp/Documents/vehicleID/val/query/*jpg")]
# Shuffle train_data
# 打乱数据的顺序
random.shuffle(train_data)
# 图像的变换, 直方图均衡化, 以及裁剪到 IMAGE_WIDTH x IMAGE_HEIGHT 的大小
def transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT):
#Histogram Equalization
img[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
img[:, :, 1] = cv2.equalizeHist(img[:, :, 1])
img[:, :, 2] = cv2.equalizeHist(img[:, :, 2])
#Image Resizing, 三次插值
img = cv2.resize(img, (img_width, img_height), interpolation = cv2.INTER_CUBIC)
return img
def make_datum(img, label):
#image is numpy.ndarray format. BGR instead of RGB
return caffe_pb2.Datum(
channels=3,
width=IMAGE_WIDTH,
height=IMAGE_HEIGHT,
label=label,
data=np.rollaxis(img, 2).tobytes()) # or .tostring() if numpy < 1.9
# 打开 lmdb 环境, 生成一个数据文件,定义最大空间, 1e12 = 1000000000000.0
in_db = lmdb.open(train_lmdb, map_size=int(1e12))
with in_db.begin(write=True) as in_txn: # 创建操作数据库句柄
for in_idx, img_path in enumerate(train_data):
if in_idx % 6 == 0: # 只处理 5/6 的数据作为训练集
continue # 留下 1/6 的数据用作验证集
# 读取图像. 做直方图均衡化、裁剪操作
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)
if 'cat' in img_path: # 组织 label, 这里是如果文件名称中有 'cat', 标签就是 0
label = 0 # 如果图像名称中没有 'cat', 有的是 'dog', 标签则为 1
else: # 这里方, label 需要自己去组织
label = 1 # 每次情况可能不一样, 灵活点
datum = make_datum(img, label)
# '{:0>5d}'.format(in_idx):
# lmdb的每一个数据都是由键值对构成的,
# 因此生成一个用递增顺序排列的定长唯一的key
in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString()) #调用句柄,写入内存
print '{:0>5d}'.format(in_idx) + ':' + img_path
# 结束后记住释放资源,否则下次用的时候打不开。。。
in_db.close()
# 创建验证集 lmdb 格式文件
print '\nCreating validation_lmdb'
in_db = lmdb.open(validation_lmdb, map_size=int(1e12))
with in_db.begin(write=True) as in_txn:
for in_idx, img_path in enumerate(train_data):
if in_idx % 6 != 0:
continue
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)
if 'cat' in img_path:
label = 0
else:
label = 1
datum = make_datum(img, label)
in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString())
print '{:0>5d}'.format(in_idx) + ':' + img_path
in_db.close()
print '\nFinished processing all images'
精简版 这段代码并没有用真实的图像数据来生成,二是用 numpy 中的 np.zeros() 生成了图像格式的数据:
import numpy as np
import lmdb
import caffe
N = 1000
# Let's pretend this is interesting data
X = np.zeros((N, 3, 32, 32), dtype=np.uint8)
y = np.zeros(N, dtype=np.int64)
# We need to prepare the database for the size. We'll set it 10 times
# greater than what we theoretically need. There is little drawback to
# setting this too big. If you still run into problem after raising
# this, you might want to try saving fewer entries in a single
# transaction.
map_size = X.nbytes * 10
env = lmdb.open('mylmdb', map_size=map_size)
with env.begin(write=True) as txn:
# txn is a Transaction object
for i in range(N):
datum = caffe.proto.caffe_pb2.Datum()
datum.channels = X.shape[1]
datum.height = X.shape[2]
datum.width = X.shape[3]
datum.data = X[i].tobytes() # or .tostring() if numpy < 1.9
datum.label = int(y[i])
str_id = '{:08}'.format(i)
# The encode is only essential in Python 3
txn.put(str_id.encode('ascii'), datum.SerializeToString())
运行上一段代码,会生成下面两个文件:
Caffe 从 LMDB 数据库中读取数据
下面就是从生成好的 lmdb 中读取数据了:
import numpy as np
import caffe
import lmdb
import cv2
# 打开 lmdb 数据库, 指定好位置
env = lmdd.open('mylmdb', readonly=True)
with env.begin() as txn:
raw_datum = txn.get(b'00000000')
datum = caffe.proto.caffe_pb2.Datum()
datum.ParseFromString(raw_datum)
flat_x = np.fromstring(datum.data, dtype=np.uint8)
x = flat_x.reshape(datum.channels, datum.height, datum.width)
y = datum.label
print(datum.channels)
print 'label = ' + str(y) # y 为整型, 需要转成字符串
# C x H x W 转换到 H x W x C, 才能在 cv2 中显示
img = cv2.transpose(img, (1, 2, 0)) # 或者: img = x.transpose(1, 2, 0)
cv2.imshow("Image", img)
cv2.waitKey(0)
输出为:
可以迭代读取
with env.open() as txn:
cursor = txn.cursor()
for key, value in cursor:
print(key, value)
下面代码用迭代循环 txn.cursor() 读取:
import caffe
from caffe.proto import caffe_pb2
import lmdb
import cv2
import numpy as np
lmdb_env = lmdb.open('mylmdb', readonly=True) # 打开数据文件
lmdb_txn = lmdb_env.begin() # 生成处理句柄
lmdb_cursor = lmdb_txn.cursor() # 生成迭代器指针
datum = caffe_pb2.Datum() # caffe 定义的数据类型
for key, value in lmdb_cursor: # 循环获取数据
datum.ParseFromString(value) # 从 value 中读取 datum 数据
label = datum.label
data = caffe.io.datum_to_array(datum)
print data.shape
print datum.channels
image = data.transpose(1, 2, 0)
cv2.imshow('cv2.png', image)
cv2.waitKey(0)
cv2.destroyAllWindows()
lmdb_env.close()
本文链接:http://nix.pub/article/lmdb/