电竞比分网-中国电竞赛事及体育赛事平台

分享

PVT:可用于密集任務(wù)backbone的金字塔視覺(jué)transformer!

 漢無(wú)為 2021-04-01

設(shè)為星標(biāo),干貨直達(dá)!

自從ViT之后,關(guān)于vision transformer的研究呈井噴式爆發(fā),從思路上分主要沿著兩大個(gè)方向,一是提升ViT在圖像分類的效果;二就是將ViT應(yīng)用在其它圖像任務(wù)中,比如分割和檢測(cè)任務(wù)上,這里介紹的PVT(Pyramid Vision Transformer) 就屬于后者。PVT相比ViT引入了和CNN類似的金字塔結(jié)構(gòu),使得PVT像CNN那樣作為backbone應(yīng)用在dense prediction任務(wù)(分割和檢測(cè)等)。

圖片

CNN結(jié)構(gòu)常用的是一種金字塔架構(gòu),如上圖所示,CNN網(wǎng)絡(luò)一般可以劃分為不同的stage,在每個(gè)stage開(kāi)始時(shí),特征圖的長(zhǎng)和寬均減半,而特征維度(channel)擴(kuò)寬2倍。這主要有兩個(gè)方面的考慮,一是采用stride=2的卷積或者池化層對(duì)特征降維可以增大感受野,另外也可以減少計(jì)算量,但同時(shí)空間上的損失用channel維度的增加來(lái)彌補(bǔ)。但是ViT本身就是全局感受野,所以ViT就比較簡(jiǎn)單直接了,直接將輸入圖像tokens化后就不斷堆積相同的transformer encoders,這應(yīng)用在圖像分類上是沒(méi)有太大的問(wèn)題。但是如果應(yīng)用在密集任務(wù)上,會(huì)遇到問(wèn)題:一是分割和檢測(cè)往往需要較大的分辨率輸入,當(dāng)輸入圖像增大時(shí),ViT的計(jì)算量會(huì)急劇上升;二是ViT直接采用較大patchs進(jìn)行token化,如采用16x16大小那么得到的粗粒度特征,對(duì)密集任務(wù)來(lái)說(shuō)損失較大。這正是PVT想要解決的問(wèn)題,PVT采用和CNN類似的架構(gòu),將網(wǎng)絡(luò)分成不同的stages,每個(gè)stage相比之前的stage特征圖的維度是減半的,這意味著tokens數(shù)量減少4倍,具體結(jié)構(gòu)如下:

圖片

每個(gè)stage的輸入都是一個(gè)維度的3-D特征圖,對(duì)于第1個(gè)stage,輸入就是RGB圖像,對(duì)于其它stage可以將tokens重新reshape成3-D特征圖。在每個(gè)stage開(kāi)始,首先像ViT一樣對(duì)輸入圖像進(jìn)行token化,即進(jìn)行patch embedding,patch大小均采用2x2大?。ǖ?個(gè)stage的patch大小是4x4),這意味著該stage最終得到的特征圖維度是減半的,tokens數(shù)量對(duì)應(yīng)減少4倍。PVT共4個(gè)stage,這和ResNet類似,4個(gè)stage得到的特征圖相比原圖大小分別是1/4,1/8,1/16和1/32。由于不同的stage的tokens數(shù)量不一樣,所以每個(gè)stage采用不同的position embeddings,在patch embed之后加上各自的position embedding,當(dāng)輸入圖像大小變化時(shí),position embeddings也可以通過(guò)插值來(lái)自適應(yīng)。

不同的stage的tokens數(shù)量不同,越靠前的stage的patchs數(shù)量越多,我們知道self-attention的計(jì)算量與sequence的長(zhǎng)度的平方成正比,如果PVT和ViT一樣,所有的transformer encoders均采用相同的參數(shù),那么計(jì)算量肯定是無(wú)法承受的。PVT為了減少計(jì)算量,不同的stages采用的網(wǎng)絡(luò)參數(shù)是不同的。PVT不同系列的網(wǎng)絡(luò)參數(shù)設(shè)置如下所示,這里為patch的size,為特征維度大小,為MHA(multi-head attention)的heads數(shù)量,為FFN的擴(kuò)展系數(shù),transformer中默認(rèn)為4。

圖片

可以見(jiàn)到隨著stage,特征的維度是逐漸增加的,比如stage1的特征維度只有64,而stage4的特征維度為512,這種設(shè)置和常規(guī)的CNN網(wǎng)絡(luò)設(shè)置是類似的,所以前面stage的patchs數(shù)量雖然大,但是特征維度小,所以計(jì)算量也不是太大。不同體量的PVT其差異主要體現(xiàn)在各個(gè)stage的transformer encoder的數(shù)量差異。

PVT為了進(jìn)一步減少計(jì)算量,將常規(guī)的multi-head attention (MHA)用spatial-reduction attention (SRA)來(lái)替換。SRA的核心是減少attention層的key和value對(duì)的數(shù)量,常規(guī)的MHA在attention層計(jì)算時(shí)key和value對(duì)的數(shù)量為sequence的長(zhǎng)度,但是SRA將其降低為原來(lái)的。SRA的具體結(jié)構(gòu)如下所示:

圖片


在實(shí)現(xiàn)上,首先將維度為的patch embeddings通過(guò) reshape變換到維度為的3-D特征圖,然后均分大小為的patchs,每個(gè)patchs通過(guò)線性變換將得到維度為的patch embeddings(這里實(shí)現(xiàn)上其實(shí)和patch emb操作類似,等價(jià)于一個(gè)卷積操作),最后應(yīng)用一個(gè)layer norm層,這樣就可以大大降低K和V的數(shù)量。具體實(shí)現(xiàn)代碼如下:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0f'dim {dim} should be divided by num_heads {num_heads}.'

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        # 實(shí)現(xiàn)上這里等價(jià)于一個(gè)卷積層
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0213)

        if self.sr_ratio > 1:
            x_ = x.permute(021).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(021# 這里x_.shape = (B, N/R^2, C)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -12, self.num_heads, C // self.num_heads).permute(20314)
        else:
            kv = self.kv(x).reshape(B, -12, self.num_heads, C // self.num_heads).permute(20314)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2-1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(12).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

從PVT的網(wǎng)絡(luò)設(shè)置上,前面的stage的取較大的值,比如stage1的,說(shuō)明這里直接將Q和V的數(shù)量直接減為原來(lái)的1/64,這個(gè)就大大降低計(jì)算量了。

PVT具體到圖像分類任務(wù)上,和ViT一樣也通過(guò)引入一個(gè)class token來(lái)實(shí)現(xiàn)最后的分類,不過(guò)PVT是在最后的一個(gè)stage才引入:

    def forward_features(self, x):
        B = x.shape[0]

        # stage 1
        x, (H, W) = self.patch_embed1(x)
        x = x + self.pos_embed1
        x = self.pos_drop1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0312).contiguous()

        # stage 2
        x, (H, W) = self.patch_embed2(x)
        x = x + self.pos_embed2
        x = self.pos_drop2(x)
        for blk in self.block2:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0312).contiguous()

        # stage 3
        x, (H, W) = self.patch_embed3(x)
        x = x + self.pos_embed3
        x = self.pos_drop3(x)
        for blk in self.block3:
            x = blk(x, H, W)
        x = x.reshape(B, H, W, -1).permute(0312).contiguous()

        # stage 4
        x, (H, W) = self.patch_embed4(x)
        cls_tokens = self.cls_token.expand(B, -1-1# 引入class token
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed4
        x = self.pos_drop4(x)
        for blk in self.block4:
            x = blk(x, H, W)

        x = self.norm(x)

        return x[:, 0]

具體到分類任務(wù)上,PVT在ImageNet上的Top-1 Acc其實(shí)是和ViT差不多的。其實(shí)PVT最重要的應(yīng)用是作為dense任務(wù)如分割和檢測(cè)的backbone,一方面PVT通過(guò)一些巧妙的設(shè)計(jì)使得對(duì)于分辨率較大的輸入圖像,其模型計(jì)算量不像ViT那么大,論文中比較了ViT-Small/16 ,ViT-Small,PVT-Small和ResNet50四種網(wǎng)絡(luò)在不同的輸入scale下的GFLOPs,可以看到PVT相比ViT要好不少,當(dāng)輸入scale=640時(shí),PVT-Small和ResNet50的計(jì)算量是類似的,但是如果到更大的scale,PVT的增長(zhǎng)速度就遠(yuǎn)超過(guò)ResNet50了。

圖片

PVT的另外一個(gè)相比ViT的優(yōu)勢(shì)就是其可以輸出不同scale的特征圖,這對(duì)于分割和檢測(cè)都是非常重要的。因?yàn)槟壳按蟛糠值姆指詈蜋z測(cè)模型都是采用FPN結(jié)構(gòu),而PVT這個(gè)特性可以使其作為替代CNN的backbone而無(wú)縫對(duì)接分割和檢測(cè)的heads。論文中做了大量的關(guān)于檢測(cè),語(yǔ)義分割以及實(shí)例分割的實(shí)驗(yàn),可以看到PVT在dense任務(wù)的優(yōu)勢(shì)。比如,在更少的推理時(shí)間內(nèi),基于PVT-Small的RetinaNet比基于R50的RetinaNet在COCO上的AP值更高(38.7 vs. 36.3),雖然繼續(xù)增加scale可以提升效果,但是就需要額外的推理時(shí)間:

圖片


所以雖然PVT可以解決一部分問(wèn)題,但是如果輸入圖像分辨率特別大,可能基于CNN的方案還是最優(yōu)的。另外曠視最新的一篇論文YOLOF指出其實(shí)ResNet一個(gè)C5特征加上一些增大感受野的模塊就可以在檢測(cè)上實(shí)現(xiàn)類似的效果,這不得不讓人思考多尺度特征是不是必須的,而且transformer encoder本身就是全局感受野的。近期Intel提出的DPT直接在ViT模型的基礎(chǔ)上通過(guò)Reassembles operation來(lái)得到不同scale的特征圖以用于dense任務(wù),并在ADE20K語(yǔ)義分割數(shù)據(jù)集上達(dá)到新的SOTA(mIoU 49.02)。而在近日,微軟提出的Swin Transformer和PVT的網(wǎng)絡(luò)架構(gòu)和很類似,但其性能在各個(gè)檢測(cè)和分割數(shù)據(jù)集上效果達(dá)到SOTA(在ADE20K語(yǔ)義分割數(shù)據(jù)集mIoU 53.5),其核心提出了一種shifted window方法來(lái)減少self-attention的計(jì)算量。

相信未來(lái)會(huì)有更好的work!期待!

參考

  1. Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions
  2. whai362/PVT
  3. 大白話Pyramid Vision Transformer
  4. You Only Look One-level Feature
  5. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  6. Vision Transformers for Dense Prediction

    本站是提供個(gè)人知識(shí)管理的網(wǎng)絡(luò)存儲(chǔ)空間,所有內(nèi)容均由用戶發(fā)布,不代表本站觀點(diǎn)。請(qǐng)注意甄別內(nèi)容中的聯(lián)系方式、誘導(dǎo)購(gòu)買等信息,謹(jǐn)防詐騙。如發(fā)現(xiàn)有害或侵權(quán)內(nèi)容,請(qǐng)點(diǎn)擊一鍵舉報(bào)。
    轉(zhuǎn)藏 分享 獻(xiàn)花(0

    0條評(píng)論

    發(fā)表

    請(qǐng)遵守用戶 評(píng)論公約

    類似文章 更多