Python创建CRNN训练用的LMDB数据库文件

CRNN简介


CRNN由 Baoguang Shi, Xiang Bai, Cong Yao提出,2015年7月发表论文:“An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition”,链接地址:https://arxiv.org/abs/1507.05717v1


CRNN(卷积循环神经网络)集成了卷积神经网络(CNN)和循环神经网络(RNN)的优点。CRNN可以直接从序列标签(例如单词,句子)中学习,不需要详细的单个分别标注,并且对图像序列对象的长度无限定,只需要在训练和测试阶段对图像高度做一下归一化。于现有技术相比,CRNN在场景文本识别上表现良好。

CRNN中训练数据的格式是LMDB,保存了两种数据,一种是图片数据,一种是标签数据,它们各有其key,如下所示:




准备CRNN训练数据集


数据集图片是若干带有文字的图片,文字的高度约占图片高度的80%~90%,数据集标签是txt文本格式,文本内容是图片上的文字,文本名字要跟图片名字一致,如123.jpg对应标签需要是123.txt。


例如有 01.jpg 和 02.jpg 两个样本,标签文件是 01.txt 和 02.txt :





创建用于CRNN训练的LMDB数据


# -*- coding: utf-8 -*-
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
#from genLineText import GenTextImage

def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    imageBuf = np.fromstring(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True


def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.iteritems():
            txn.put(k, v)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
    """
    #print (len(imagePathList) , len(labelList))
    assert(len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    print '...................'
    # map_size=1099511627776 定义最大空间是1TB
    env = lmdb.open(outputPath, map_size=1099511627776)
    
    cache = {}
    cnt = 1
    for i in xrange(nSamples):
        imagePath = imagePathList[i]
        label = labelList[i]
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'r') as f:
            imageBin = f.read()
        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue


        ########## .mdb数据库文件保存了两种数据,一种是图片数据,一种是标签数据,它们各有其key
        imageKey = 'image-%09d' % cnt
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label
        ##########
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            cache[lexiconKey] = ' '.join(lexiconList[i])
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt-1
    cache['num-samples'] = str(nSamples)
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)


def read_text(path):
    
    with open(path) as f:
        text = f.read()
    text = text.strip()
    
    return text


import glob
if __name__ == '__main__':
    
    #lmdb 输出目录
    outputPath = '../data/lmdb/trainMy'

    # 训练图片路径,标签是txt格式,名字跟图片名字要一致,如123.jpg对应标签需要是123.txt
    path = '../data/dataline/*.jpg'

    imagePathList = glob.glob(path)
    print '------------',len(imagePathList),'------------'
    imgLabelLists = []
    for p in imagePathList:
        try:
           imgLabelLists.append((p,read_text(p.replace('.jpg','.txt'))))
        except:
            continue
            
    #imgLabelList = [ (p,read_text(p.replace('.jpg','.txt'))) for p in imagePathList]
    ##sort by lebelList
    imgLabelList = sorted(imgLabelLists,key = lambda x:len(x[1]))
    imgPaths = [ p[0] for p in imgLabelList]
    txtLists = [ p[1] for p in imgLabelList]
    
    createDataset(outputPath, imgPaths, txtLists, lexiconList=None, checkValid=True)



读取LMDB数据集中图片


# -*- coding: utf-8 -*-
import numpy as np
import lmdb
import cv2

with lmdb.open("../data/lmdb/train") as env:
    txn = env.begin()
    for key, value in txn.cursor():
        print (key,value)
        imageBuf = np.fromstring(value, dtype=np.uint8)
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
        if img is not None:
            cv2.imshow('image', img)
            cv2.waitKey()
        else:
            print 'This is a label: {}'.format(value)

原文地址:https://www.cnblogs.com/mtcnn/p/9411755.html