我们使用 caffe,将图像数据转换为 caffe 可以识别的数据格式是第一步。同时大多数都是通过 python 接口来转换数据格式的。

LMDB 数据库 Caffe 使用 LMDB 的情况大约有两类:

第一类是 DataLayer 层中 使用的 训练集、验证集、测试集;
第二类 就是 ./caffe/build/tools/extract_feature.bin 这种特征提取工具提取特征后,输出的特征文件。 LMDB 的全称是 Lighting Memory-Mapped Database(闪电般的内存映射数据库) 。它文件结构简单,一个文件夹,里面一个数据文件,一个锁文件。数据随意复制,随意传输。它的访问简单,不需要运行单独的数据管理进程。只要在访问的代码里引用 LMDB 库,访问时给文件路径即可。

Caffe 中使用的数据较为很简单,就是大量的矩阵/向量平铺开来。数据之间没有什么关联,数据内没有复杂的对象结构,就是向量和矩阵。既然数据并不复杂,Caffe 就选择了 LMDB 这个简单的数据库来存放数据。

上面提到了,Caffe 使用 LMDB 数据库有两点原因:

一方面是因为数据源的格式多样性,有文本文件、二进制文件图像文件等等,不可能用一个代码完成上述所有的数据格式。因此,通过 LMDB 数据库,转化成统一的数据格式可以简化数据读取层的实现。

第二个方面就是使用 LMDB 数据库可以大大的节约磁盘 IO 的时间开销。因为读取大量小文件的时间开销是相当大的,尤其是在机械硬盘上。 数据库单文件还能减少数据集复制、传输过程的开销。因为我们都有过体会,一个具有几万个、几十万个文件的数据集,不管是直接复制,还是打开再解包,过程都巨慢无比。LMDB 只有一个文件,你的介质有多快,就能复制多快,不会因为文件多而慢的令人心碎。

Caffe 中 Datum 数据结构 Caffe 并不是把向量和矩阵直接放进数据库的,而是将数据通过 caffe.proto 里定义的一个 datum 类来封装的。数据库里存放的是一个个 datum 序列化成的字符串。Datum 的定义如下:

message Datum {
  optional int32 channels = 1;
  optional int32 height = 2;
  optional int32 width = 3;
  // the actual image data, in bytes
  optional bytes data = 4;
  optional int32 label = 5;
  // Optionally, the datum could also hold float data.
  repeated float float_data = 6;
  // If true data contains an encoded image that need to be decoded
  optional bool encoded = 7 [default = false];
}

一个 Datum 有三个维度,channnels、height、width,可以看作是少了 num 维度的 Blob。 存放数据的地方有两个:bytes data、float_data,分别存放整数型和浮点型数据。图像数据一般是整形,放在 bytes data 中,特征向量一般是浮点型,存放在 float_data 中。 label 里存放的是类别标签,是整数型。 encoded 标识数据是否需要被解码,因为里面可能存放的是 JPEG 或者 PNG 之类经过编码的数据。

Datum 这个数据结构将数据和标签封装在一起,兼容整形和浮点型数据。经过 protobuf 编译后,可以在 Python 和 C++ 中都提供高效的访问。 同时 protobuf 还为它提供了序列化、反序列化的功能。存放进 LMDB 的就是 Datum 序列化生成的字符串。

Caffe 中将图像写入 LMDB 数据库

我上面解析的 create_mnist_data.cpp 代码对于这部分是很有用的,特别是 LMDB 流程图中的 lmdb 数据操作函数,如打开一个 lmdb 数据库,写入数据等操作,python 中的使用类似,但比 C++ 的要简洁许多 。

下面通过代码来说明吧,这段代码是一个大牛写的教程:《A Practical Introduction to Deep Learning with Caffe and Python》,写的很清晰。

import os
import glob
import random
import numpy as np

import cv2

import caffe
from caffe.proto import caffe_pb2
import lmdb

#Size of images
IMAGE_WIDTH = 227
IMAGE_HEIGHT = 227

# train_lmdb、validation_lmdb 路径
train_lmdb = '/home/chenxp/Documents/vehicleID/val/train_lmdb'
validation_lmdb = '/home/chenxp/Documents/vehicleID/val/validation_lmdb'

# 如果存在了这个文件夹, 先删除
os.system('rm -rf  ' + train_lmdb)
os.system('rm -rf  ' + validation_lmdb)

# 读取图像
train_data = [img for img in glob.glob("/home/chenxp/Documents/vehicleID/val/query/*jpg")]
test_data = [img for img in glob.glob("/home/chenxp/Documents/vehicleID/val/query/*jpg")]

# Shuffle train_data
# 打乱数据的顺序
random.shuffle(train_data)

# 图像的变换, 直方图均衡化, 以及裁剪到 IMAGE_WIDTH x IMAGE_HEIGHT 的大小
def transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT):
    #Histogram Equalization
    img[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
    img[:, :, 1] = cv2.equalizeHist(img[:, :, 1])
    img[:, :, 2] = cv2.equalizeHist(img[:, :, 2])

    #Image Resizing, 三次插值
    img = cv2.resize(img, (img_width, img_height), interpolation = cv2.INTER_CUBIC)
    return img

def make_datum(img, label):
    #image is numpy.ndarray format. BGR instead of RGB
    return caffe_pb2.Datum(
        channels=3,
        width=IMAGE_WIDTH,
        height=IMAGE_HEIGHT,
        label=label,
        data=np.rollaxis(img, 2).tobytes()) # or .tostring() if numpy < 1.9

# 打开 lmdb 环境, 生成一个数据文件,定义最大空间, 1e12 = 1000000000000.0
in_db = lmdb.open(train_lmdb, map_size=int(1e12)) 
with in_db.begin(write=True) as in_txn: # 创建操作数据库句柄
    for in_idx, img_path in enumerate(train_data):
        if in_idx %  6 == 0: # 只处理 5/6 的数据作为训练集
            continue         # 留下 1/6 的数据用作验证集
        # 读取图像. 做直方图均衡化、裁剪操作
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)

        if 'cat' in img_path: # 组织 label, 这里是如果文件名称中有 'cat', 标签就是 0
            label = 0         # 如果图像名称中没有 'cat', 有的是 'dog', 标签则为 1
        else:                 # 这里方, label 需要自己去组织
            label = 1         # 每次情况可能不一样, 灵活点

        datum = make_datum(img, label)
        # '{:0>5d}'.format(in_idx):
        #      lmdb的每一个数据都是由键值对构成的,
        #      因此生成一个用递增顺序排列的定长唯一的key
        in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString()) #调用句柄,写入内存
        print '{:0>5d}'.format(in_idx) + ':' + img_path

# 结束后记住释放资源,否则下次用的时候打不开。。。
in_db.close() 

# 创建验证集 lmdb 格式文件
print '\nCreating validation_lmdb'
in_db = lmdb.open(validation_lmdb, map_size=int(1e12))
with in_db.begin(write=True) as in_txn:
    for in_idx, img_path in enumerate(train_data):
        if in_idx % 6 != 0:
            continue
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)
        if 'cat' in img_path:
            label = 0
        else:
            label = 1
        datum = make_datum(img, label)
        in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString())
        print '{:0>5d}'.format(in_idx) + ':' + img_path
in_db.close()
print '\nFinished processing all images'

精简版 这段代码并没有用真实的图像数据来生成,二是用 numpy 中的 np.zeros() 生成了图像格式的数据:

import numpy as np
import lmdb
import caffe

N = 1000

# Let's pretend this is interesting data
X = np.zeros((N, 3, 32, 32), dtype=np.uint8)
y = np.zeros(N, dtype=np.int64)

# We need to prepare the database for the size. We'll set it 10 times
# greater than what we theoretically need. There is little drawback to
# setting this too big. If you still run into problem after raising
# this, you might want to try saving fewer entries in a single
# transaction.
map_size = X.nbytes * 10

env = lmdb.open('mylmdb', map_size=map_size)

with env.begin(write=True) as txn:
    # txn is a Transaction object
    for i in range(N):
        datum = caffe.proto.caffe_pb2.Datum()
        datum.channels = X.shape[1]
        datum.height = X.shape[2]
        datum.width = X.shape[3]
        datum.data = X[i].tobytes()  # or .tostring() if numpy < 1.9
        datum.label = int(y[i])
        str_id = '{:08}'.format(i)

        # The encode is only essential in Python 3
        txn.put(str_id.encode('ascii'), datum.SerializeToString())

运行上一段代码,会生成下面两个文件:

lmdb1

Caffe 从 LMDB 数据库中读取数据

下面就是从生成好的 lmdb 中读取数据了:

import numpy as np
import caffe
import lmdb
import cv2

# 打开 lmdb 数据库, 指定好位置
env = lmdd.open('mylmdb', readonly=True)
with env.begin() as txn:
    raw_datum = txn.get(b'00000000')

datum = caffe.proto.caffe_pb2.Datum()
datum.ParseFromString(raw_datum)

flat_x = np.fromstring(datum.data, dtype=np.uint8)
x = flat_x.reshape(datum.channels, datum.height, datum.width)
y = datum.label

print(datum.channels)
print 'label = ' + str(y) # y 为整型, 需要转成字符串

# C x H x W 转换到 H x W x C, 才能在 cv2 中显示
img = cv2.transpose(img, (1, 2, 0)) # 或者: img = x.transpose(1, 2, 0)
cv2.imshow("Image", img)
cv2.waitKey(0)

输出为: lmdb2

可以迭代读取

with env.open() as txn:
    cursor = txn.cursor()
    for key, value in cursor:
        print(key, value)

下面代码用迭代循环 txn.cursor() 读取:

import caffe
from caffe.proto import caffe_pb2

import lmdb
import cv2
import numpy as np

lmdb_env = lmdb.open('mylmdb', readonly=True) # 打开数据文件
lmdb_txn = lmdb_env.begin() # 生成处理句柄
lmdb_cursor = lmdb_txn.cursor() # 生成迭代器指针
datum = caffe_pb2.Datum() # caffe 定义的数据类型

for key, value in lmdb_cursor: # 循环获取数据
    datum.ParseFromString(value) # 从 value 中读取 datum 数据

    label = datum.label
    data = caffe.io.datum_to_array(datum)
    print data.shape
    print datum.channels
    image = data.transpose(1, 2, 0)
    cv2.imshow('cv2.png', image)
    cv2.waitKey(0)

cv2.destroyAllWindows()
lmdb_env.close()

本文链接:http://nix.pub/article/lmdb/