Loading...
墨滴

咚咚

2021/11/30  阅读:73  主题:默认主题

自监督学习MOCO算法解析

代码:Momentum Contrast for Unsupervised Visual Representation Learning

数据增强

输入图像x大小设为(b, c, h, w)

通过对输入图像x进行以下数据增强方式:

随机裁剪到224*224大小->RandomGrayscale->ColorJitter->RandomHorizontalFlip

获取两个 增强视觉, 大小均为(b, c, 224, 224)

网络

整体网络结构如上图所示

具有左右两分支网络,其中左分支中query encoder和右分支中key encoder结构一样

右分支还多一个queue模块

encoder

encoder是常用的backbone结构,例如resnet、vgg、alexnet等, 输出特征q大小为(b, cls)

而对于右分支key encoder(即图中的momentum encoder),需要先对输入图像 进行shuffle batch操作得到新的输入图像


shuffle batch: 只针对多GPU,而且是使用ddp

假如使用4个GPUs,每个GPU中的batchsize=8, 每个GPU中特征大小为(8, c, 224, 224)

那么所有输入特征x_gather大小为(32, c, 224, 224)

总batchsize=4*8=32

对32个数进行shuffle操作,得到不连续的32个数值idx_shuffle,例如[5, 2, 7, 8, 1, 0, ...], 大小为(32,), 可view成(4, 8)大小, 那么每个GPU中可取值x_gather[idx_shuffle[i]], 获得新的大小为(8, c, 224, 224)的输入图像, 这时的图像与 就不一定是来自同一张图像

代码如下:

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """

        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle


将经过shuffle batch的新的输入图像 送入momentum encoder, 得到输出特征k, 大小为(b, cls)

再对输出特征k进行unshuffle batch,与shuffle batch 互为逆操作,得到与q来自同一个图像的特征k

queue

随机初始化一个大小为(cls, K)的队列queue,K为队列长度, cls是维度大小

并且确定队列初始位置索引queue_ptr, 初始值为0

每次训练一次结束,都将momentum encoder的输出特征k添加到队列尾部,并剔除队头,保持整个队列的长度不变

代码如下:

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)  # 获取队列头部索引位置
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T  # 添加momentum encoder的输出特征k
        ptr = (ptr + batch_size) % self.K  # move pointer 移动队列头部索引位置,也就是删除头部多余位置

        self.queue_ptr[0] = ptr

Loss

将k作为q的正样本,因为k与q是来自同一张图像的不同视图

将queue作为q的负样本,因为queue中含有大量不同图像的视图

首先计算正样本损失l_pos, 大小为(b, 1)

        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)

再计算负样本损失l_neg, 大小为(b, K)

        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

将l_pos和l_neg进行cat操作,并除以温度参数temperature(控制concentration level of distribution), 得到logits, 大小为(b, 1+K)

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

目标是正样本都为1,负样本都为0

那么可以把logits看做分类,分成1+K个类别,期望都是第一个类别,则可以把labels设为0

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

最后使用nn.CrossEntropyLoss计算损失函数

梯度反传

右分支网络不参与直接训练,其中所有的权重参数不具有梯度值。其参数 更新方式是基于左分支网络参数 动量更新。训练开始前,两分支网络初始权重保持一致。

其中,momentum是动量值

咚咚

2021/11/30  阅读:73  主题:默认主题

作者介绍

咚咚