ULMFiT:Universal Language Model Fine-tuning for Text Classification
首先在大数据集上预训练,然后应用到下游的文本分类任务
什么是 ULMFiT ?
- 文本分类模型,但是训练分类模型前,先在大数据集上进行预训练 + 目标数据集上微调,预训练微调的方式均是子监督的,即类似 word 2 vec 的方式
- 微调过程中使用差异学习率、学习率预热等手段
ULMFiT 的网络结构?
- LM pre-training:此阶段借鉴 CV 中的 ImageNet,在通用领域语料库(文中用的是 Wikitext-103,其包含 28,595 篇维基百科文章和 1030 亿个单词)上训练语言模型,从而捕获不同层次文本的一般特征
- LM fine-tuning:此阶段的目标是针对 target domain 上微调 Language Model。在这个阶段,作者用了 Discriminative fine-tuning 和 Slanted triangular learning rates 两个 tricks,以学习不同层次文本在 target domain 上的特征
- Classifier fine-tuning:此阶段主要是训练模型的顶层结构
LM fine-tuning 的两个技巧?
- 差异学习率:由于底层特征更具有通用性,而顶层特征更具有特殊性,所以作者在训练过程中,对于不同层设置了不同的学习率
- 学习率预热:学习率先增后减。先用较小的学习率,得到一个好的优化方向,再用较大的学习率进行优化,在训练后期再使用较小的学习率进行更细致的优化
Classifier fine-tuning 的四个技巧?
- Concat pooling:如果仅使用 RNN 模型最后一个 time step 的输出,显然会丢失信息,尤其是在长文本建模中,因此作者对 RNN 所有 time state 的 hidden states 进行 max pooling 和 mean pooling,然后将 pooling 得到的两个特征与最后一个 time step 的输出连接,作为最终输出
- Gradual unfreezing:直接 fine-tuning 整个网络可能导致网络遗忘之前预训练得到的通用特征,因此作者提出了 gradual unfreezing,具体做法是自顶向下以 epoch 为单位逐步进行 fine-tuning,即第一个 epoch 只解冻最后一层,第二个 epoch 解冻最后两层,以此类推。
- BPTT for Text Classification (BPT3C):为了使大型文档的分类器微调可行,作者将文档划分为大小为 b 的固定长度批次。 在每个批次的开头,用前一批次的最终状态初始化模型,跟踪平均值和最大池的隐藏状态,梯度反向传播到批次
- Bidirectional language model:本文分别训练了前向和后向的 Language Model,在 fine-tuning 阶段对预测的结果取平均
参考: