论文阅读——CLIP算法

论文阅读——CLIP算法

原文链接:[2103.00020] Learning Transferable Visual Models From Natural Language Supervision (arxiv.org)

代码链接:openai/CLIP: CLIP (Contrastive Language-Image Pretraining), Predict the most relevant text snippet given an image (github.com)

1、算法原理

CLIP(Contrastive Language-Image Pre-training)具备很强的迁移学习能力。在无任意一张ImageNet图片训练情景下,直接进行Zero-shot推理,就能媲美监督训练下的ResNet-50模型的结果。

宏观来看CLIP分为三部分:

  • Contrastive pre-training:预训练阶段,使用图片 - 文本对进行对比学习训练;

  • Create dataset classifier from label text:提取预测类别文本特征;

  • Use for zero-shot predictiion:进行 Zero-Shoot 推理预测;

image-20230917172037149

第一阶段,图像和文本分别通过图像、文本编码器生成对应的\(l_1、l_2…l_n\)\(T_1、T_2…T_n\)的特征向量,计算对应角标向量的余弦相似度,通过temperature参数缩放,并借助softmax归一化为概率分布。图像编码器选用两个架构,第一个采用的是ResNet-50的基础架构,使用ResNetD和Rect-2进行改进,将全局平均池化层替换为一个单层的注意力池化机制;第二个采用改进的ViT模型。文本编码器使用的是一个Transformer编码器,有8个注意力头,使用了隐藏的自注意。

第二阶段,使用提示模板,帮助指定文本是否是关于图像的内容。将输出的句子通过文本编码器进行特征提取,得到特征向量。

第三阶段,输入一张图片,经过图像编码器进行特征提取生成一个特征向量,与文本特征进行余弦相似度计算,最相似的即为预测结果。

2、代码实现

伪代码如下:

image-20230918110417534

CLIP模型前向传播部分:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def forward(self, image, text):
image_features = self.encode_image(image) # 图片编码提特征
text_features = self.encode_text(text) # 文本编码提特征

# 特征归一化
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)

# 计算余弦相似度
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()

# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text

论文阅读——CLIP算法
http://example.com/2023/09/18/clip/
作者
Z Z
发布于
2023年9月18日
许可协议