论文阅读——CLIP算法
论文阅读——CLIP算法
原文链接:[2103.00020] Learning Transferable Visual Models From Natural Language Supervision (arxiv.org)
1、算法原理
CLIP(Contrastive Language-Image Pre-training)具备很强的迁移学习能力。在无任意一张ImageNet图片训练情景下,直接进行Zero-shot推理,就能媲美监督训练下的ResNet-50模型的结果。
宏观来看CLIP分为三部分:
Contrastive pre-training:预训练阶段,使用图片 - 文本对进行对比学习训练;
Create dataset classifier from label text:提取预测类别文本特征;
Use for zero-shot predictiion:进行 Zero-Shoot 推理预测;
第一阶段,图像和文本分别通过图像、文本编码器生成对应的\(l_1、l_2…l_n\)、\(T_1、T_2…T_n\)的特征向量,计算对应角标向量的余弦相似度,通过temperature参数缩放,并借助softmax归一化为概率分布。图像编码器选用两个架构,第一个采用的是ResNet-50的基础架构,使用ResNetD和Rect-2进行改进,将全局平均池化层替换为一个单层的注意力池化机制;第二个采用改进的ViT模型。文本编码器使用的是一个Transformer编码器,有8个注意力头,使用了隐藏的自注意。
第二阶段,使用提示模板,帮助指定文本是否是关于图像的内容。将输出的句子通过文本编码器进行特征提取,得到特征向量。
第三阶段,输入一张图片,经过图像编码器进行特征提取生成一个特征向量,与文本特征进行余弦相似度计算,最相似的即为预测结果。
2、代码实现
伪代码如下:
CLIP模型前向传播部分:
1 |
|