SRNet

本文是 OCR 的样本生成模型,包括风格迁移模块、前景文字去除模块及融合模块,其中风格迁移模块目地是将文本按照目标风格渲染出来,前景文字去除模块用于去掉目标图片上的文字,融合模块用于融合目标文本和目标背景

什么是 SRNet ?

  • SRNet 是一个利用 GAN 网络完成风格迁移的网络,利用该网络可以将文字按照目标图片的风格渲染到目标图片上
  • 其核心包括 3 部分:(1)文本风格转换模块;(2)背景修复模块;(3)融合模块

SRNet 的网络结构?

  • 文本风格转换模块:完成文本 + 目标图片 => 按照目标图片文字风格的图片,此时图片还没有背景
  • 背景修复模块:目标图片 => 去掉文字的图片,相当于完成目标背景的提取
  • 融合模块:目标风格的文字 + 目标背景 => 文字迁移到新背景上,并且文字风格类似

SRNet 的训练?

  • SRNet 借鉴 GAN 的训练方式,即通过训练 Discriminator-Generator 完成训练
  • Discriminator 训练目地是区分真实样本和生成样本,所以分别在ObO_bOfO_f使用两个生成器训练
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
i_db_true = torch.cat((t_b, i_s), dim = 1)
i_db_pred = torch.cat((o_b, i_s), dim = 1)


i_df_true = torch.cat((t_f, i_t), dim = 1)
i_df_pred = torch.cat((o_f, i_t), dim = 1)

o_db_true = D1(i_db_true)
o_db_pred = D1(i_db_pred)

o_df_true = D2(i_df_true)
o_df_pred = D2(i_df_pred)

db_loss = build_discriminator_loss(o_db_true, o_db_pred)
df_loss = build_discriminator_loss(o_df_true, o_df_pred)
  • Generator 训练目地是使得生成样本更接近真实样本,所以在osko_skoto_tobo_bofo_f四个位置计算损失,并且使用不参与训练的 VGG 计算ofo_ftft_f,目地是计算真实图片和目标图片在多尺度下风格、内容的一致性
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#skeleton loss
l_t_sk = build_dice_loss(t_sk, o_sk)
l_t_l1 = build_l1_loss(t_t, o_t)
l_t = l_t_l1 + l_t_sk

#Background Inpainting module loss
l_b_gan = build_gan_loss(o_db_pred) # gan loss为bce loss
l_b_l1 = build_l1_loss(t_b, o_b)
l_b = l_b_gan + l_b_l1

l_f_gan = build_gan_loss(o_df_pred)
l_f_l1 = build_l1_loss(t_f, o_f)
l_f_vgg_per, l_f_vgg_style = build_vgg_loss(o_vgg) # 分为perceptual loss和style loss
l_f = l_f_gan + l_f_vgg_per + l_f_vgg_style + l_f_l1

loss = cfg.lt * l_t + cfg.lb * l_b + cfg.lf * l_f

参考:

  1. 读论文 ——(Styletext)Editing Text in the Wild-CSDN 博客