SAM论文细读

  • 论文地址:https://arxiv.org/abs/2304.02643

    项目地址:Segment Anything

    项目演示:https://segment-anything.com/demo

    基本思想:

    1.零次学习:

    img

    下图所示:

    利用过去的知识(马,老虎,熊猫和斑马的描述),在脑海中推理出新对象的具体形态,从而能对新对象进行辨认。就是希望能够模仿人类的这个推理过程,使得计算机具有识别新事物的能力。

    img

    2.结构分析

    img

    SAM模型大致上分成3个模块,一个标准的vit构成的image encoder、一个prompt encoder和一个mask decoder。其中:

    • Image encoder: 用于输出image embedding;
    • prompt encoder:用于接收point、box、txt的编码信息,并且与image embedding组合到一起送入mask decoder中;
    • mask decoder:将上述两个encoder的编码信息转化为mask输出。

    一、image encoder

    (1)、复习一下

    先复习一下关于vit的东西。

    img

    img

    模型由三个模块组成:

    • Linear Projection of Flattened Patches(Embedding层)
    • Transformer Encoder
    • MLP Head(最终用于分类的层结构)

    工作流程如下:

    ViT的工作流程,如下:

    1. 将一张图片分成patches
    2. 将patches铺平
    3. 将铺平后的patches的线性映射到更低维的空间
    4. 添加位置embedding编码信息
    5. 将图像序列数据送入标准Transformer encoder中去
    6. 在较大的数据集上预训练
    7. 在下游数据集上微调用于图像分类

    步骤1、将图片转换成patches序列

    这一步很关键,为了让Transformer能够处理图像数据,第一步必须先将图像数据转换成序列数据,但是怎么做呢?假如我们有一张图片,patch大小为,那么我们可以创建个图像patches,可以表示为,其中,就是序列的长度,类似一个句子中单词的个数。在上面的图中,可以看到图片被分为了9个patches。

    步骤2、将Patches铺平

    在原论文中,作者选用的patch大小为16,那么一个patch的shape为(3,16,16),维度为3,将它铺平之后大小为3x16x16=768。即一个patch变为长度为768的向量。不过这看起来还是有点大,此时可以使用加一个Linear transformation,即添加一个线性映射层,将patch的维度映射到我们指定的embedding的维度,这样就和NLP中的词向量类似了。

    步骤3、添加Position embedding

    与CNNs不同,此时模型并不知道序列数据中的patches的位置信息。所以这些patches必须先追加一个位置信息,也就是图中的带数字的向量。实验表明,不同的位置编码embedding对最终的结果影响不大,在Transformer原论文中使用的是固定位置编码,在ViT中使用的可学习的位置embedding 向量,将它们加到对应的输出patch embeddings上。

    步骤4、添加class token

    在输入到Transformer Encoder之前,还需要添加一个特殊的class token,这一点主要是借鉴了BERT模型。添加这个class token的目的是因为,ViT模型将这个class token在Transformer Encoder的输出当做是模型对输入图片的编码特征,用于后续输入MLP模块中与图片label进行loss计算。

    步骤5、输入Transformer Encoder

    将patch embedding和class token拼接起来输入标准的Transformer Encoder中。 Transformer Encoder其实就是重复堆叠Encoder Block L次,主要由Layer Norm、Multi-Head Attention、Dropout和MLP Block几部分组成。

    步骤6、分类

    注意Transformer Encoder的输出其实也是一个序列,但是在ViT模型中只使用了class token的输出,将其送入MLP模块中,去输出最终的分类结果。

    Embedding层

    对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵[num_token, token_dim],如下图,token0-9对应的都是向量,以ViT-B/16为例,每个token向量长度为768。

    img

    对于图像数据而言,其数据格式为[H, W, C]是三维矩阵明显不是Transformer想要的。所以需要先通过一个Embedding层来对数据做个变换。首先将一张图片按给定大小分成一堆Patches。以ViT-B/16为例,将输入图片(224x224)按照16x16大小的Patch进行划分,划分后会得到196个Patches。接着通过线性映射将每个Patch映射到一维向量中,以ViT-B/16为例,每个Patche数据shape为[16, 16, 3]通过映射得到一个长度为768的向量(后面都直接称为token)。[16, 16, 3] -> [768]

    在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。

    img

    这里面就是把以前基于NLP的tansformer改成为图像的可输入,图像也变成二维矩阵。

    Transformer Encoder

    Transformer Encoder其实就是重复堆叠Encoder Block L次,主要由Layer Norm、Multi-Head Attention、Dropout和MLP Block几部分组成。

    img

    MLP Head

    上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。

    img

    参数说明

    img

    (2)、结构说明

    • 输入图像进入网络,先经过一个卷积base的patch_embedding:取16*16为一个patch,步长也是16,这样feature map的尺寸就缩小了16倍,同时channel从3映射到768。
    • patch_embed过后加positional_embedding:positional_embedding是个可学习的参数矩阵,初始化是0。
    • 加了positional_embedding后的feature map过16个transformer block,其中12个transformer是基于window partition(就是把特征图分成14*14的windows做局部的attention)的attn模块,和4个全局attn,这4个全局attn是穿插在windowed attention中的。
    • 最后过两层卷积(neck)把channel数降到256,这就是最终的image embedding的结果

    从结构上看,sam的encoder部分就是堆叠transformer的block结构,最后再跟一个neck,调整输出embedding的维度。Meta开源了三个模型,分别是vit_h, vit_l和vit_b,这三个模型的区别仅仅在于内部patch embedding维度、transformer的block的个数以及每个block中head的数量和全局attention的index:

    img

    1.1 图片分patch

    原图进入网络之后,按照最大边长补充成方形,再resize到1024x1024。

    1024x1024x3的图片输入进入网络后,首先使用一个16x16,stride=16,输出channel数为patch embedding维度的二维卷积。以vit_b为例,patch embedding的维度是768,因此经过卷积之后,图片变成了768x64x64的feature map,再调整维度就变成64x64x768。

    在该feature map基础上,会再加一个绝对位置编码(absolute positional embedding),所谓绝对位置编码是指生成一组与feature map同样大小(64x64x768)的可学习参数,初始化时一般为0。

    1.2 attention block

    img

    img

    img

    img

    img

    二、 Prompt encoder

    根据输入的point和boxs返回sparse embedding, 根据mask返回dense embeddings。

    PromptEncoder属于轻量化的结构,用于对输入模型的points、boxes和masks信息进行编码,将其统一为空间特征编码的格式。其对points、boxes和masks编码时允许有部分值空缺(空缺使用默认值),其将points和boxes组装为sparse_embeddings,将mask组装为dense_embeddings 其对mask的采样由多个attention层实现,具体可见mask_downscaling函数

    point embedding

    • step1:首先生成一组可学习的向量point embedding,大小为:4x1x256:

    4代表了表示pos/neg + 2 box corners,即demo里面的添加点和消除点、以及box框的左上角和右下角;

    0:neg,对应demo中的消除点

    1:pos,对应demo中的添加点

    2:代表box左上角点

    3:代表box右下角点

    • step2:再生成一组可学习的向量not_a_point_embed,大小为1x256,用于表示该位置不是一个点
    • step3:如果传入的prompt里面没有bbox,则补充一个【0,0】点到每个point后面,其对应的label为-1
    • step4:如果传入的还有bbox,此时的point大小为Nx1x2,label为Nx1
    • step5:再根据point计算point embedding,其流程如下:

    横纵坐标先归一化,即都除以输入的尺寸(1024, 1024);

    再将point矩阵与一个随机高斯矩阵(2x128)矩阵相乘得到Nxax128的矩阵coord,其中(a=2表示只有point,a=1表示还有box作为prompt输入);

    再分别对coord计算sin和cos,拼接矩阵得到最终的point embedding(Nxax256)

    • step:6再根据label,给point embedding加上之前生成的可学习的embeding向量

    img

    box embedding

    bbox一般有2个点,其编码步骤如下:

    step1: 和point一样,先四个点resize为Nx2x2;

    step2: 再使用point embedding编码的方式,得到corner_embedding,

    step3: 再加上之前生成的可学习的embeding向量;

    最后输出的corner_embedding大小为Nx2x256。

    最后输出的box的embedding的尺寸是Nx2x256。

    合并(concat)point embedding和corner embedding,可以得到sparse embedding:

    全都没有:sparse embedding(1x0x256)

    如果只有point:sparse embedding(Nx2x256)

    如果只有box:sparse embedding(Nx2x256)

    piont、box都有:sparse embedding(Nx3x256)

    mask embedding

    那么对于mask这类的dense prompt,他的映射就比较简单粗暴。在输入prompt encoder之前,先要把mask降采样到4x,再过两个2x2,stride=2的卷积,这样尺寸又降了4x,就和降了16x的图像特征图尺寸一致了,再过一个1*1的卷积,把channel也升到256。如果没有提供mask,也就是我们实际inference时候的场景,这个结构会直接返回一个描述“没有mask”特征的特征图

    三、mask decoder

    img

    img

    decoder的结构之所以看起来复杂,主要原因是prompt embedding和image embedding在这个结构中反复融合并且反复更新,从这里同样可以看出prompt在这个任务中的重要地位。

    我们从左至右逐步分析decoder的流程,

    • 在prompt embedding进入decoder之前,先在它上面concat了一组可学习的output tokens,output tokens由两个部分构成:
      • 一个是iou token,它会在后面被分离出来用于预测iou的可靠性(对应结构图右侧的IoU output token),它受到模型计算出的iou与模型计算出的mask与GT实际的iou之间的MSE loss监督;
      • 另一个是mask token,它也会在后面被分离出来参与预测最终的mask(对应结构图右侧的output token per mask),mask受到focal loss和dice loss 20:1的加权组合监督。
      • 这两个token的意义我感觉比较抽象,因为理论来说进入decoder的变量应该是由模型的输入,也就是prompt和image的映射构成,但这两个token的定义与prompt和image完全没有关系,而是凭空出现的。从结果反推原因,只能把它们理解成对模型的额外约束,因为它们两个参与构成了模型的两个输出并且有loss对他们进行监督。
      • 最终prompt embedding(这一步改名叫prompt token)和刚才提到这两个token concat到一起统称为tokens进入decoder。
    • image embedding在进入decoder之前也要进行一步操作:dense prompt由于包含密集的空间信息,与image embedding所在的特征空间一致性更高,所以直接与image embedding相加融合。因为后面要与prompt做cross attention融合,这里还要先算一下image embedding的位置编码。
    • 接下来{image embedding,image embedding的位置编码,tokens}进入一个两层transformer结构的decoder做融合。值得注意的是,在transformer结构中,为了保持位置信息始终不丢失,每做一次attention运算,不管是self-attention还是cross-attention,tokens都叠加一次初始的tokens,image embedding都叠加一次它自己的位置编码,并且每个attention后边都接一个layer_norm。
      • tokens先过一个self-attention。
      • tokens作为q,对image embedding做cross attention,更新tokens。
      • tokens再过两层的mlp做特征变换。
      • image embedding作为q,对tokens做cross attention,更新image embedding。
    • 更新后的tokens作为q,再对更新后的image embedding做cross attention,产生最终的tokens。
    • 更新后的image embedding过两层kernel_size=2, stride=2的转置卷积,升采样到4x大小(依然是4x降采样原图的大小),产生最终的image embedding。
    • 接下来兵分两路:
      • mask token被从tokens中分离出来(因为他一开始就是concat上去的,可以直接按维度摘出来),过一个三层的mlp调整channel数与最终的image embedding一致,并且他们两个做矩阵乘法生成mask的预测。
      • iou token被从tokens中分离出来,也过一个三层的mlp生成最终的iou预测。
    • 最后,如前文所述,分别对mask的预测和iou预测进行监督,反向传播,更新参数。