文本检测网络Psenet学习(三)

  现有的文本检测方法主要有两大类,一种是基于回归框的检测方法(基于物体检测的方法),如CTPN,EAST,这类方法很难检测任意形状的文本(曲线文本), 一种是基于像素的分割检测器(基于实例分割的方法),这类方法很难将彼此非常接近的文本实例分开。Psenet文本检测方法是基于分割的方法,在2019年的论文Shape Robust Text Detection with Progressive Scale Expansion Network 中提出,优化了近距离文本实例的分离。

  对于Psenet的学习,主要在于四方面:网络结构的设计,kernel的生成,渐进尺度扩展算法(progressive scale expansion),loss函数

1. 网络结构的设计

  Psenet网络采用了resnet+fpn的架构,通过resnet提取特征,取不同层的特征送入fpn进行特征融合,其结构如下图所示:

   上图中给出了训练过程中网络数据流,总结如下:

  1. 1*3*640*640的图片输入网络,经过Resnet网络,将layer1,layer2,layer3,layer4的特征图p1(1*256*160*160), p2(1*512*80*80), p3(1*1024*40*40), p4(1*2048*20*20)送入fpn

  2. 以此对应p1, p2, p3, p4, fpn网络输出特征c1(1*256*160*160), c2(1*256*80*80), c3(1*256*40*40), c4(1*256*20*20)

  3. c2, c3, c4分别上采样2,4,8倍后和c1进行concat得到特征1*1024*160*160,再经过两个卷积输出1*7*160*160,上采样4倍得到网络最终的输出1*7*640*640。

  4.网络最后输出了7个640*640的预测图(map),分别表示预测的text_predict,和6个kernel_predict

  另外,上述采用resnet50的典型结构如下:

  

2. kernel的产生

  上面网络结构中提到模型最后输出7个640*640的预测图, 分别是预测的text,和6个kernel,因此在训练时也需要通过标注数据产生7个640*640的map供网络学习,即text_gt和6个kernel_gt。其中text_gt就是一张二值图,白色部分表示img中含有文字的区域,黑色部分表示背景区域,kernel_gt就是在text_gt的基础上,将白色区域按一定的比例缩小。如下图所示,根据r计算出d,表示该kernel的白色区域边缘部分相对于text_gt的白色区域向内部移动了d个像素。

3. 渐进尺度扩展算法(progressive scale expansion)

  在进行推理时,需要从网络输出的6个kernel中得到需要的box,作者采用了pse(progressive scale exoansion)算法。假设有kernel1,kernel2, kernel3, kernel4, kernel5, kernel6,先从文字区域最小的kernel6开始,遍历其白色区域的像素点,采用广度优先法向四周扩展,依次合并kernel2, kernel3, kernel4, kernel5, kernel6, 最后合并得到一个kernel,整个合并算法看代码比较好理解。取合并后kernel白色区域的矩形框或轮廓线即得到文字检测框。论文中示意图如下:

  参考python代码如下:

import numpy as np
import cv2
# import Queue
from queue import Queue

def pse(kernals, min_area):
    kernal_num = len(kernals)
    pred = np.zeros(kernals[0].shape, dtype='int32')
    
    label_num, label = cv2.connectedComponents(kernals[kernal_num - 1], connectivity=4)
    
    for label_idx in range(1, label_num):
        if np.sum(label == label_idx) < min_area:
            label[label == label_idx] = 0

    queue = Queue.Queue(maxsize = 0)
    next_queue = Queue.Queue(maxsize = 0)
    points = np.array(np.where(label > 0)).transpose((1, 0))
    
    for point_idx in range(points.shape[0]):
        x, y = points[point_idx, 0], points[point_idx, 1]
        l = label[x, y]
        queue.put((x, y, l))
        pred[x, y] = l

    dx = [-1, 1, 0, 0]
    dy = [0, 0, -1, 1]
    for kernal_idx in range(kernal_num - 2, -1, -1):
        kernal = kernals[kernal_idx].copy()
        while not queue.empty():
            (x, y, l) = queue.get()

            is_edge = True
            for j in range(4):
                tmpx = x + dx[j]
                tmpy = y + dy[j]
                if tmpx < 0 or tmpx >= kernal.shape[0] or tmpy < 0 or tmpy >= kernal.shape[1]:
                    continue
                if kernal[tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
                    continue

                queue.put((tmpx, tmpy, l))
                pred[tmpx, tmpy] = l
                is_edge = False
            if is_edge:
                next_queue.put((x, y, l))
        
        # kernal[pred > 0] = 0
        queue, next_queue = next_queue, queue
        
        # points = np.array(np.where(pred > 0)).transpose((1, 0))
        # for point_idx in range(points.shape[0]):
        #     x, y = points[point_idx, 0], points[point_idx, 1]
        #     l = pred[x, y]
        #     queue.put((x, y, l))

    return pred
pse算法

4. loss函数理解

  psenet的loss包括两部分,gt_text和kernel的loss,都采用dice loss计算损失值。总的loss计算如公司如下,权重系数一般取λ=0.7

  dice loss的计算公式如下,参见代码比较好理解

  dice loss 参考代码:

def dice_loss(input, target, mask):
    #input为预测的map
    #target为标注的map
    input = torch.sigmoid(input)

    input = input.contiguous().view(input.size()[0], -1)
    target = target.contiguous().view(target.size()[0], -1)
    mask = mask.contiguous().view(mask.size()[0], -1)

    input = input * mask
    target = target * mask

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    d = (2 * a) / (b + c)
    dice_loss = torch.mean(d)
    return 1 - dice_loss
dice loss示意代码

 参考:

  https://github.com/whai362/PSENet

  https://github.com/WenmuZhou/PSENet.pytorch

原文地址:https://www.cnblogs.com/silence-cho/p/14151233.html