文本聚类之Supporting Clustering with Contrastive Learning(SCCL)论文复现【代码纠错及完善】
文本聚类之Supporting Clustering with Contrastive Learning(SCCL)论文复现
前言
近期想探索一些自己研究方向相关的领域,因而有了这篇文章,从图像的聚类到文本的聚类。本文的文本聚类更确切地说是短文本聚类(short text clustering),这是一个并不火热的方向,可以看到这篇论文是两年前发表的,SCCL是一篇很不错的工作。文本数据这块比较火热的是基于自监督、对比学习的NLP,聚类常被作为一个下游子任务,文本数据方面的主要任务还是学习文本的语义表征。不过,这个方向也未尝不是一个值得探索的方向(有着训练时间短,对显卡资源要求低的优点,算法稳定性强),没准哪天也像对比学习那样迎来一个春天。
一、论文介绍
论文名称:Supporting Clustering with Contrastive Learning,发表在NAACL会议上,在NLP领域也算是顶会吧。
论文链接:https://arxiv.org/abs/2103.12953
github地址:https://github.com/amazon-science/sccl
摘要:无监督聚类的目的是根据在表征(representation)中度量的距离发现数据的语义类别空间。然而,在学习过程的开始阶段,不同的类别往往在表征空间中相互重叠,这对基于距离的聚类在实现不同类别之间的良好分离提出了重大挑战。为此,本文提出了基于对比学习的支持聚类(SCCL)——一个利用对比学习来促进更好分离的新框架。并评估了SCCL在短文本聚类(short text clustering)上的性能,表明SCCL在大多数基准数据集上显著提高了达到SOTA的结果,准确率(ACC)提高了3% ~ 11%,规范化互信息(NMI)提高了4% ~ 15%。此外,定量分析证明了SCCL在利用自下而上的实例判别(instance discrimination)和自上而下的聚类优势方面的有效性,从而在使用ground truth聚类标签进行评估时获得更好的聚类内和聚类间距离。
二、复现步骤
本人复现代码已发布到github:https://github.com/Regan-Zhang/SCCL-EXEC
复现过程可参考sccl的github项目中的README.md文档,一般没有什么问题的话即可完成本节如下的第1步数据集下载及处理、第2步数据预处理、第3步依赖库的安装。
1. 数据集下载及处理
上图是SCCL原论文对数据集的数据集情况表一览。基于运行时间考虑,选取数据量适中的AgNews文本数据,余下部分本人也以AgNews为例进行讲解。
⋅
\huge ·
⋅ 进入数据集网址:
https://github.com/rashadulrakib/short-text-clustering-enhancement/tree/master/data
然后点击红框处下载就可以啦。这些下载下来的都是原始文本数据,并没有经过处理和数据增强。
⋅
\huge ·
⋅ 下载下来后文件名为agnewsdataraw-8000.txt,将其后缀名改成csv,当然这里我是改成了agnewsdataraw.csv,命名看读者喜好可以自己定义但是后缀名一定要为csv!
⋅
\huge ·
⋅ 用WPS或者Excel打开csv文件,不出意外的话只有一列数据,label标签和text文本数据糊成一片。不过不要紧,这是因为直接打开csv文件而没有通过源导入定义格式导致的。只要重新导入文件即可。导入过程都在下面(其实也可以新建一个空的csv文件再导入源):
如果不新建的话在原来csv文件操作就需要把光标放到第一个单元。
⋅
\huge ·
⋅ 点击“数据”–>“导入数据”–>“导入数据”–>选择“直接打开数据文件”–>“选择数据源”,选中我们的agnewsdataraw.csv进行导入。
⋅
\huge ·
⋅ 选择编码,没有尝试过其他,最常见编码是GBK和UTF-8,此处我选择UTF-8,这样不容易出错。
⋅
\huge ·
⋅ 选择分隔符号,下一步。
⋅
\huge ·
⋅ 选择Tab键,这步很关键,源数据就是通过Tab分隔的。
⋅
\huge ·
⋅ 接着就是下一步然后完成。得到如下结果就证明导入成功了:
⋅
\huge ·
⋅ 可以看到原来光标的位置多出两行,然后我们把原先的那一列删掉。
在首行补充label和text的列名称:
最后保存,数据集处理就大功告成啦!
2. 数据预处理
文本数据预处理,按文档要求对文本数据集进行数据增强(data augmentation)。
数据增强的代码位于./Augdata目录下,修改路径和数据集名称就能运行啦。细节可以看我发布的源码【之后不久会发布到我个人github】。可以看到我导出源重新建了个文件夹AgNews。
3. 安装依赖库
pytorch==1.6.0.
sentence-transformers==2.0.0.
transformers==4.8.1.
tensorboardX==2.4.1
pandas==1.1.5
sklearn==0.24.1
numpy==1.19.5
这是源项目提供的依赖库,不过其中还少了一个很关键的库:nlpaug,这个库看缩写就知道是用于NLP领域的数据增强,pip install 直接安装即可。
pip install nlpaug
或许每个人运行代码后都会出现各种各样问题,但就我个人情况是遇到这个问题(如下图):proto协议版本过高,需要降低这个库的版本。先uninstall然后指定版本进行install即可。
pip uninstall protobuf
pip install protobuf==3.19.0
按照报错提示,安装3.19.0这个版本。实际上不管遇到什么问题,看懂报错提示就可以了,然后搜索对应解决方法,兵来将挡水来土掩。
比如上图所示的报错,需要更新sentence transformer这个库为最新版本。
pip install -U sentence-transformers
3. 代码修改
不知道是作者故意的,还是笔误,对应位置改成distilbert就可以了。
上图这个越界的错误是整个SCCL项目中最大的错误,但修改起来也很容易,只需要加一句代码。
报错的原因是agnewsdataraw.csv中的标签是从1开始打标签,的agnews有4个类,标签为
[
1
,
2
,
3
,
4
]
[1,2,3,4]
[1,2,3,4],而编程的世界里编号都是从零开始的,标签应为
[
0
,
1
,
2
,
3
]
[0,1,2,3]
[0,1,2,3],因此找到对应位置将 label值
−
1
-1
−1就行啦(如图)。
值得一提的是,由于代码运行时会自动从github下载BERT的预训练权重,所以需要科学上网,这个请读者自行了解解决。
4. 运行结果
复现在Nvidia GeForce RTX3090(显存24G)上运行,遵照原项目参数配置,运行起来后占显存为15G+,还是对显卡蛮友好的,而且运行速度快以至于都不用保存权重,运行时间大概在40分钟左右。如果读者显卡资源没有这么充足,可以减小batch size等超参去适应显卡的配置。
上图是论文中汇报的结果,而我针对AgNews数据集的复现的效果如下,使用两种数据增强(charswap字符交换和word deletion词删除):
AgNews | NMI | ACC |
---|---|---|
charswap | 66.9 | 87.3 |
word deletion | 67.5 | 87.5 |
最好的结果(即word deletion这一行)和论文里汇报相差一个点,还是蛮接近的,由此也证明模型和论文的真实性和可行性。