View on GitHub

pycorrector

pycorrector is a toolkit for text error correction. 文本纠错,Kenlm,Seq2Seq_Attention,BERT,MacBERT,ELECTRA,ERNIE,Transformer等模型实现,开箱即用。

ELECTRA Model

Requirements

使用说明

  1. 下载ELECTRA-base, Chinese Pytorch模型,解压后放置于data/electra_models目录并解压缩即可下。
    electra_models
    └── chinese_electra_base_generator_pytorch
     ├── config.json
     ├── pytorch_model.bin
     └── vocab.txt
    
  2. 运行electra_corrector.py进行纠错。
    python3 electra_corrector.py
    
  3. 评估

todo

简介

ELECTRA提出了一套新的预训练框架,其中包括两个部分:GeneratorDiscriminator

为了进一步促进中文预训练模型技术的研究与发展,哈工大讯飞联合实验室基于官方ELECTRA训练代码以及大规模的中文数据训练出中文ELECTRA预训练模型供大家下载使用。 其中ELECTRA-small模型可与BERT-base甚至其他同等规模的模型相媲美,而参数量仅为BERT-base的1/10。

更详细的内容请查阅ELECTRA论文:ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators

模型下载

本目录中包含以下模型,目前提供PyTorch和TensorFlow版本权重。

模型简称 语料 Google下载 压缩包大小
ELECTRA-base, Chinese 中文维基+通用数据 TensorFlow
PyTorch-D
PyTorch-G
383M
ELECTRA-small, Chinese 中文维基+通用数据 TensorFlow
PyTorch-D
PyTorch-G
46M

以PyTorch版ELECTRA-base, Chinese为例,下载完毕后对zip文件进行解压得到:

chinese_electra_base_discriminator_pytorch
├── config.json                # 模型配置文件
├── pytorch_model.bin          # 模型权重文件
└── vocab.txt                  # 词表

chinese_electra_base_generator_pytorch
├── config.json
├── pytorch_model.bin
└── vocab.txt

训练细节

我们采用了大规模中文维基以及通用文本训练了ELECTRA模型,总token数达到5.4B,与RoBERTa-wwm-ext系列模型一致。词表方面沿用了谷歌原版BERT的WordPiece词表,包含21128个token。其他细节和超参数如下(未提及的参数保持默认):

快速加载

本项目支持ELECTRA模型,可通过如下命令调用。

example: predict_mask.py

import os
from transformers import ElectraForPreTraining, ElectraTokenizer
pwd_path = os.path.abspath(os.path.dirname(__file__))

D_model_dir = os.path.join(pwd_path, "../data/electra_models/chinese_electra_base_discriminator_pytorch/")
tokenizer = ElectraTokenizer.from_pretrained(D_model_dir)
discriminator = ElectraForPreTraining.from_pretrained(D_model_dir)

Feature

  1. 虽然像BERT这样的MASK语言建模(MLM)预训练方法在下游的NLP任务上产生了很好的结果,但是它们需要大量的计算才能有效。 这些方法通过用[MASK]替换一些Token来破坏输入,然后训练一个模型来重构Token。

  2. 作为一种替代方案,我们提出了一种更具效率的预训练,称为Replaced token detection(RTD)判断当前词是否被替换了。 我们的方法不是屏蔽输入,而是用从小型GAN中提取的plausible alternatives sampled替换一些输入Token,从而破坏输入。然后,我们不是训练一个模型来预测[MASK], 而是训练一个判别模型来[MASK]输入中的每个Token是否被生成器样本替换。实验表明,这种预训练模型更有效,因为它从所有的输入中学习,而不是仅仅从[MASK]中。

  3. 在相同的模型大小、数据和计算条件下,通过我们的方法学习的上下文表示大大优于通过BERT和XLNet等方法学习的上下文表示。



本文一个突出贡献就是将GAN引入到预训练语言模型中,并且取得了SOTA(state of the art)的效果。

通过对MASK后的句子生成样本,这里使用的是MLM(maximum likelihood),而不是Adversarially。

通过序列标注的方法,判断当前词是否是原文(original,replaced),GAN生成的都是假的,但是本文的G会生成真实样本(部分词汇真实),梯度不能从D传到G,所以使用强化学习的方法来训练G。

正常情况下,Generator和Discriminator使用相同的大小,但实验表明,Generator更小效果会好点。

Generator和Discriminator只是对于token embeddings进行共享,如果将所有权重共享效果反而会差点。

实验表明,Generator的大小为Discriminator的1/4~1/2效果最好,作者提出,Generator太大会给Discriminator造成困扰。

将Generator和Discriminator进行联合训练,开始只训练通过MLM去训练Generator,然后用Generator的参数去初始化Discriminator,训练Discriminator同时冻结 Generator的参数。

通过对比学习方法来区分虚构的负样本与正样本。

References