CRNN: An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition

传统的文本识别方法需要先对单个文字进行切割,然后再对单个文字进行识别。CRNN直接从整张图出发,将图片特征输入RNN进行序列学习,最后通过CTC计算损失

什么是 CRNN ?

  1. CNN 提取特征:使用轻量化网络 MobileNetv3,其中输入图像的高度统一设置为32,宽度可以为任意长度,经过 CNN 网络后,特征图的高度缩放为1
  2. Im2Seq:将 CNN 获取的特征图变换为 RNN 需要的特征向量序列的形状
  3. 双向 LSTM(BiLSTM)对特征序列进行预测:学习序列中的每个特征向量并输出预测标签分布。这里其实相当于把特征向量的宽度视为 LSTM 中的时间维度
  4. 全连接层分类:使用全连接层对每个序列进行 N+1 类别预测,获取模型的预测结果
  5. CTC:解码模型输出的预测结果,得到最终输出

CRNN 的网络结构?

  • CRNN-20230408144101-1
  • 卷积层,使用 CNN,作用是从输入图像中提取特征序列,要求输入高度必须是 32,以便以上 5 次下采样,可以将其约简到 1 d,如 (512,1,40) 可以认为是 40 个时间步,每个时间步特征向量长度为 512 的,然后才能使用 RNN 学习
  • 循环层,使用RNN,作用是预测从卷积层获取的特征序列的标签(真实值)分布;
  • 转录层,使用CTC,作用是把从循环层获取的标签分布通过去重整合等操作转换成最终的识别结果

CRNN 如何计算 CTC loss?

  • Drawing 2023-03-22 17.23.38.excalidraw
    1. 如何直接对序列进行预测,字符后续后处理时,无法找出连续的字符,在预测的字符之间插入连字符后解决该问题
    1. CTC loss 主要是解决不对齐序列的损失计算,主要原理有 2 个地方:1)加入连字符预测,用于处理连续出现的字符;2)构建所有时间步的状态转移矩阵,根据状态转移矩阵求出得到 gt 预测序列的所有路径,最后1-所有路径概率相加=CTC loss
  • 3.状态转移矩阵的序列转移规则:1)方向只能向下或向右;2)相同字符之间一定有一个空字符;3)非空字符不能被跳过;4)起点必须从前 2 个字符开始;5)终点必须在结尾 2 个连续字符
  • Pytorch 有专门的 API 去求解 CTC loss,直接调用即可
    1
    2
    3
    4
    5
    6
    7
    >>> ctc_loss = nn.CTCLoss()
    >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
    >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
    >>> input_lengths = torch.full((16,), 50, dtype=torch.long)
    >>> target_lengths = torch.randint(10,30,(16,), dtype=torch.long)
    >>> loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
    >>> loss.backward()

CRNN 的输出如何解析?

  • Drawing 2023-03-22 19.00.35.excalidraw
  • 直接拿到所有时间步的预测结果,然后合并 2 个连字符之间的相同字符,得到最终结果
  • 以上是 12 个时间步的输出结果,去掉连字符内重复字符的过程如: hhe–lll-llo 、 he–l-lo 、hello

参考:

  1. CRNN_AI路上的小白的博客-CSDN博客
  2. CRNN——卷积循环神经网络结构_猛男技术控的博客-CSDN博客