电竞比分网-中国电竞赛事及体育赛事平台

分享

ICCV 2021 | 最快視覺Transformer!Facebook提出LeViT:快速推理的視...

 taotao_2016 2021-07-30

AI/CV重磅干貨,第一時間送達

CVer
CVer
一個專注侃侃計算機視覺方向的公眾號。計算機視覺、圖像處理、機器學習、深度學習、C/C++、Python、詩和遠方等。
204篇原創(chuàng)內容
公眾號
本文轉載自:集智書童
圖片

LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference

論文:https:///abs/2104.01136

代碼(剛剛開源):

https://github.com/facebookresearch/LeViT

吸取CNN優(yōu)點!LeViT:快速推理的視覺Transformer,在速度/準確性的權衡方面LeViT明顯優(yōu)于現有的CNN和視覺Transformer,比如ViT、DeiT等,而且top-1精度為80%的情況下LeViT比CPU上的EfficientNet快3.3倍。
作者單位:Facebook

1 簡介

本文的工作利用了基于注意力體系結構中的最新發(fā)現,該體系結構在高度并行處理硬件上具有競爭力。作者從卷積神經網絡的大量文獻中重新評估了原理,以將其應用于Transformer,尤其是分辨率降低的激活圖。同時作者還介紹了Attention bias,一種將位置信息集成到視覺Transformer中的新方法。

圖片
圖1 LeViT性能對比

最終作者提出了LeVIT:一種用于快速推理的混合神經網絡??紤]在不同的硬件平臺上采用不同的效率衡量標準,以最好地反映各種應用場景。作者通過廣泛的實驗表明該方法適用于大多數體系結構??傮w而言,在速度/準確性的權衡方面,LeViT明顯優(yōu)于現有的卷積網絡和視覺Transformer。例如,在ImageNet Top-1精度為80%的情況下,LeViT比CPU上的EfficientNet快3.3倍。

相同計算復雜度的情況下Transformer為什么快?

大多數硬件加速器(gpu,TPUs)被優(yōu)化以用來執(zhí)行大型矩陣乘法。在Transformer中,注意力機制和MLP塊主要依靠這些操作。相比之下,卷積需要復雜的數據訪問模式,因此它們的操作通常受io約束。這些考慮對于我們探索速度/精度的權衡是很重要的。

本文主要貢獻:

  1. 采用注意力機制作為下采樣機制的multi-stage transformer 結構;

  2. 一種計算效率高的patch descriptor,可以減少第一層特征的數量;

  3. 使用Translation-invariant attention bias取代ViT中的位置嵌入;

  4. 為了提高給定計算時間的網絡容量,作者重新設計了Attention-MLP Block。

2 LeViT的設計

2.1 LeViT設計原則

LeViT以ViT的架構和DeiT的訓練方法為基礎,合并了對卷積架構有用的組件。第1步是獲得Compatible Representation。如果不考慮classification embedding的作用,ViT就是一個處理激活映射的Layer的堆疊。

圖片

實際上,中間“Token”嵌入可以看作是FCN體系結構中傳統(tǒng)的C×H×W激活映射(BCHW格式)。因此,適用于激活映射(池、卷積)的操作可以應用于DeiT的中間表征。

LeViT優(yōu)化了計算體系結構,不一定是為了最小化參數的數量。ResNet系列比VGG更高效的設計原則之一是在其前2個階段使用相對較小的計算預算應用strong resolution reductions。當激活映射到達ResNet的第3階段時,其分辨率已經縮小到足以將卷積應用于小的激活映射,從而降低了計算成本。

2.2 LeViT組件

1、Patch embedding

初步分析表明,在transformer組的輸入上應用一個小卷積可以提高精度。因此在LeViT中作者選擇對輸入應用4層3×3卷積(stride2)來降低分辨率。channel的數量是C=3,32,64,128,256。

以上操作減少了對transformer下層的激活映射的輸入,同時不丟失重要信息。LeViT-256的patch extractor用184 MFLOPs將圖像形狀(3,224,224)轉換為(256,14,14)。作為比較,ResNet-18的前10層使用1042 MFLOPs執(zhí)行相同的dimensionality reduction。

為什么在transformer組的輸入上應用一個小卷積可以提高精度?

圖片

2、No classification token

為了使用BCHW張量形式,LeViT刪除了classification token。類似于卷積網絡,在最后一個激活映射上使用GAP來代替,這將產生一個用于分類器的embedding。在訓練中進行蒸餾,作者分別訓練分類和蒸餾的Head。在測試時,平均2個Head的輸出。在實踐中,LeViT可以使用BNC或BCHW張量格式。

3、Normalization layers and activations

ViT架構中的FC層相當于1x1卷積。ViT在每個注意點和MLP單元之前使用層歸一化。對于LeViT,每次卷積之后都要進行BN操作。然后與residual connection連接起來的每個BN權重參數初始化為零。BN可以與之前的卷積合并來進行推理,這比層歸一化有運行優(yōu)勢(例如,在EfficientNet B0上,這種融合將GPU的推理速度提高了2倍)。而DeiT使用GELU函數,而LeViT的非線性激活都是Hardswish。

class Linear_BN(torch.nn.Sequential):
    def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
        super().__init__()
        self.add_module('c', torch.nn.Linear(a, b, bias=False))
        bn = torch.nn.BatchNorm1d(b)
        torch.nn.init.constant_(bn.weight, bn_weight_init)
        torch.nn.init.constant_(bn.bias, 0)
        self.add_module('bn', bn)

        global FLOPS_COUNTER
        output_points = resolution**2
        FLOPS_COUNTER += a * b * output_points

    @torch.no_grad()
    def fuse(self):
        l, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = l.weight * w[:, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        m = torch.nn.Linear(w.size(1), w.size(0))
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

    def forward(self, x):
        l, bn = self._modules.values()
        x = l(x)
        return bn(x.flatten(01)).reshape_as(x)

4、Multi-resolution pyramid

LeViT在transformer架構中集成了ResNet stage。在各個stage中,該體系結構類似于一個visual transformer:一個帶有交替MLP和激活塊的殘差模塊。下面是注意塊的修改。

圖片
class Attention(torch.nn.Module):
    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 activation=None,
                 resolution=14)
:

        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        h = self.dh + nh_kd * 2
        self.qkv = Linear_BN(dim, h, resolution=resolution)
        self.proj = torch.nn.Sequential(activation(), Linear_BN(
            self.dh, dim, bn_weight_init=0, resolution=resolution))

        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N))

        global FLOPS_COUNTER
        #queries * keys
        FLOPS_COUNTER += num_heads * (resolution**4) * key_dim
        # softmax
        FLOPS_COUNTER += num_heads * (resolution**4)
        #attention * v
        FLOPS_COUNTER += num_heads * self.d * (resolution**4)

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):  # x (B,N,C)
        B, N, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.view(B, N, self.num_heads, -
                           1).split([self.key_dim, self.key_dim, self.d], dim=3)
        q = q.permute(0213)
        k = k.permute(0213)
        v = v.permute(0213)

        attn = (
            (q @ k.transpose(-2-1)) * self.scale
            +
            (self.attention_biases[:, self.attention_bias_idxs]
             if self.training else self.ab)
        )
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(12).reshape(B, N, self.dh)
        x = self.proj(x)
        return x

5、Downsampling

在LeViT stage之間,一個縮小的注意塊減少了激活映射的大小:在Q轉換之前應用一個subsampling,然后傳播到soft activation的輸出。這將一個大小為的輸入張量映射到一個大小為的輸出張量。由于尺度的變化這個注意塊的使用沒有殘差連接。同時為了防止信息丟失,這里將注意力頭的數量設為

class Subsample(torch.nn.Module):
    def __init__(self, stride, resolution):
        super().__init__()
        self.stride = stride
        self.resolution = resolution

    def forward(self, x):
        B, N, C = x.shape
        x = x.view(B, self.resolution, self.resolution, C)[
            :, ::self.stride, ::self.stride].reshape(B, -1, C)
        return x

6、Attention bias instead of a positional embedding

在transformer架構中的位置嵌入是一個位置依賴可訓練的向量,在將token嵌入輸入到transformer塊之前,將其添加到token嵌入。如果沒有它,轉換器輸出將獨立于輸入標記的排列。位置嵌入的Ablations會導致分類精度的急劇下降。

然而,位置嵌入只包含在注意塊序列的輸入上。因此,由于位置編碼對higher layer也很重要,所以它很可能仍然處于中間表示中。

因此,LeViT在每個注意塊中提供位置信息,并在注意機制中明確地注入相對位置信息:只是在注意力圖中添加了注意偏向。對于每個head ,每2個像素之間的標量值計算方式為:

圖片

第一項是經典的注意力。第二個是translation-invariant attention bias。每個Head有H×W參數對應不同的像素偏移量。對稱差異鼓勵用 flip invariance進行訓練。

self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))

7、Smaller keys

由于translation-invariant attention bias偏置項減少了key對位置信息編碼的壓力,因此LeViT減少了key矩陣相對于V矩陣的大小。如果key大小為, V則有2D通道。key的大小可以減少計算key product 所需的時間。

對于沒有殘差連接的下采樣層,將V的維數設置為4D,以防止信息丟失。

8、Attention activation

在使用常規(guī)線性投影組合不同Heads的輸出之前,對product 應用Hardswish激活。這類似于ResNet bottleneck residual block,V是一個1×1卷積的輸出,對應一個spatial卷積,projection是另一個1×1卷積。

9、Reducing the MLP blocks

在ViT中,MLP residual塊是一個線性層,它將嵌入維數增加了4倍,然后用一個非線性將其減小到原來的嵌入維數。但是對于視覺架構,MLP通常在運行時間和參數方面比注意Block更昂貴。

對于LeViT, MLP是1x1卷積,然后是通常的BN。為了減少計算開銷,將卷積的展開因子從4降低到2。一個設計目標是注意力和MLP塊消耗大約相同數量的FLOPs。

2.3 LeViT家族

圖片

3 實驗

3.1 速度對比

圖片

ResNet50的精度,但是是起飛的速度。

3.2 SOTA對比

圖片


論文PDF和代碼下載

    本站是提供個人知識管理的網絡存儲空間,所有內容均由用戶發(fā)布,不代表本站觀點。請注意甄別內容中的聯(lián)系方式、誘導購買等信息,謹防詐騙。如發(fā)現有害或侵權內容,請點擊一鍵舉報。
    轉藏 分享 獻花(0

    0條評論

    發(fā)表

    請遵守用戶 評論公約

    類似文章 更多