【Loss】Center loss code walkthrough (PyTorch)

*Note:* the full code is at the end; I don't know who originally wrote it.

Notation: $m$ is the batch size, $n$ is the number of classes, and $d$ is the feature dimension, so the feature matrix is $x \in \mathbb{R}^{m \times d}$ and the learnable class centers form $centers \in \mathbb{R}^{n \times d}$:

$$x = \begin{bmatrix} x_{00} & x_{01} & \cdots & x_{0(d-1)} \\ x_{10} & x_{11} & \cdots & x_{1(d-1)} \\ \vdots & \vdots & & \vdots \\ x_{(m-1)0} & x_{(m-1)1} & \cdots & x_{(m-1)(d-1)} \end{bmatrix}, \qquad centers = \begin{bmatrix} c_{00} & \cdots & c_{0(d-1)} \\ \vdots & & \vdots \\ c_{(n-1)0} & \cdots & c_{(n-1)(d-1)} \end{bmatrix}$$

The first term of the distance matrix is

```python
torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes)
```

Squaring element-wise and summing over `dim=1` with `keepdim=True` produces the column of squared row norms

$$\begin{bmatrix} x_{00}^2 + x_{01}^2 + \cdots + x_{0(d-1)}^2 \\ \vdots \\ x_{(m-1)0}^2 + x_{(m-1)1}^2 + \cdots + x_{(m-1)(d-1)}^2 \end{bmatrix} \in \mathbb{R}^{m \times 1},$$

and `expand` repeats that column $n$ times into an $m \times n$ matrix whose $(i, j)$ entry is $\lVert x_i \rVert^2$. The centers are handled the same way:

```python
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
```

Here the squared row norms $\lVert c_j \rVert^2 = c_{j0}^2 + c_{j1}^2 + \cdots + c_{j(d-1)}^2$ form an $n \times 1$ column, `expand` turns it into an $n \times m$ matrix, and `.t()` transposes it to $m \times n$ so the $(i, j)$ entry is $\lVert c_j \rVert^2$. Adding the two gives

```python
distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
          torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
```

with $distmat_{ij} = \lVert x_i \rVert^2 + \lVert c_j \rVert^2$. Then

```python
distmat.addmm_(1, -2, x, self.centers.t())
```

updates it in place as $distmat \leftarrow 1 \cdot distmat - 2\, x\, centers^{\mathsf T}$, so

$$distmat_{ij} = \lVert x_i \rVert^2 + \lVert c_j \rVert^2 - 2\, x_i \cdot c_j = \lVert x_i - c_j \rVert^2, \quad 0 \le i \le m-1,\ 0 \le j \le n-1,$$

i.e. the squared distance between sample $i$ and class center $j$. (On recent PyTorch versions this positional `addmm_` signature is deprecated; the keyword form is `distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)`.) Combined with the mask built next, this lets each sample select the distance to its own class center.

```python
classes = torch.arange(self.num_classes).long()   # classes = 0, 1, ..., n-1
labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
mask = labels.eq(classes.expand(batch_size, self.num_classes))
```

The input `labels` holds the class id of each sample. `unsqueeze(1)` turns it into an $m \times 1$ column, and `expand(batch_size, self.num_classes)` repeats it so every element of row $i$ equals the label of sample $i$; comparing this element-wise with the expanded `classes` effectively converts `labels` into a one-hot mask. Example with batch size = 3 and num classes = 4:
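To sanity-check the expand/addmm trick numerically, here is a small standalone sketch; the sizes `m=3, n=4, d=5` and the use of `torch.cdist` as a reference are my own choices, not from the original code. It builds `distmat` the same way (using the non-deprecated keyword form of `addmm`) and compares it against directly computed squared pairwise distances:

```python
import torch

torch.manual_seed(0)
m, n, d = 3, 4, 5                      # batch size, num classes, feature dim
x = torch.randn(m, d)
centers = torch.randn(n, d)

# ||x_i||^2 repeated across columns + ||c_j||^2 repeated across rows
distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) \
        + torch.pow(centers, 2).sum(dim=1, keepdim=True).expand(n, m).t()
# distmat <- 1 * distmat - 2 * x @ centers.t()  (keyword form of addmm)
distmat = distmat.addmm(x, centers.t(), beta=1, alpha=-2)

# reference: squared Euclidean distances computed directly
ref = torch.cdist(x, centers).pow(2)
print(torch.allclose(distmat, ref, atol=1e-5))   # True
```

Every entry matches $\lVert x_i - c_j \rVert^2$ up to floating-point error, confirming the derivation above.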
My understanding of this part: the mask picks out, for each sample, the center of its own class, and the loss then keeps shrinking the distance between each sample and that center. Continuing the toy example, take $labels = [2, 1, 0] \in \mathbb{R}^{m}$; after `unsqueeze(1)` it is $m \times 1$, and after `expand` it is $m \times n$:

$$labels = \begin{bmatrix} 2 & 2 & 2 & 2 \\ 1 & 1 & 1 & 1 \\ 0 & 0 & 0 & 0 \end{bmatrix}, \qquad classes.expand = \begin{bmatrix} 0 & 1 & 2 & 3 \\ 0 & 1 & 2 & 3 \\ 0 & 1 & 2 & 3 \end{bmatrix}, \qquad mask = \begin{bmatrix} F & F & T & F \\ F & T & F & F \\ T & F & F & F \end{bmatrix}$$

Row $i$ of `mask` is exactly the one-hot encoding of label $i$. The loss then gathers the masked entries:

```python
dist = []
for i in range(batch_size):
    value = distmat[i][mask[i]]
    value = value.clamp(min=1e-12, max=1e+12)  # for numerical stability
    dist.append(value)
dist = torch.cat(dist)
loss = dist.mean()
return loss
```

`distmat[i][mask[i]]` keeps only $d_{i, y_i} = \lVert x_i - c_{y_i} \rVert^2$, and the loss is the mean of these squared distances over the batch. (In the paper the center loss is $L_C = \frac{1}{2} \sum_i \lVert x_i - c_{y_i} \rVert^2$; this implementation drops the $\frac{1}{2}$ and averages instead of summing.)

Full code (with one small fix: the original docstring said `labels` has shape `(num_classes)`, but it should be `(batch_size)`):

```python
import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """
    def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu
        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"
        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # distmat = 1 * distmat - 2 * x @ self.centers.t()
        # (on recent PyTorch use: distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2))
        distmat.addmm_(1, -2, x, self.centers.t())

        classes = torch.arange(self.num_classes).long()
        if self.use_gpu:
            classes = classes.cuda()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = []
        for i in range(batch_size):
            value = distmat[i][mask[i]]
            value = value.clamp(min=1e-12, max=1e+12)  # for numerical stability
            dist.append(value)
        dist = torch.cat(dist)
        loss = dist.mean()
        return loss
```
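The mask construction can be checked in isolation; a minimal sketch with the same toy numbers (batch size 3, labels `[2, 1, 0]`, 4 classes):

```python
import torch

batch_size, num_classes = 3, 4
labels = torch.tensor([2, 1, 0])

classes = torch.arange(num_classes).long()                       # [0, 1, 2, 3]
expanded = labels.unsqueeze(1).expand(batch_size, num_classes)   # row i filled with labels[i]
mask = expanded.eq(classes.expand(batch_size, num_classes))

# row i of mask is the one-hot encoding of labels[i]
expected = torch.tensor([[False, False, True,  False],
                         [False, True,  False, False],
                         [True,  False, False, False]])
print(bool((mask == expected).all()))   # True
```

This reproduces the `mask` matrix from the walkthrough without any explicit one-hot conversion.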
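As a final sanity check, the whole vectorized forward pass can be compared against a naive per-sample loop. This is a CPU-only sketch with small made-up sizes (and the non-deprecated keyword form of `addmm`), not part of the original code:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, num_classes, feat_dim = 3, 4, 5
x = torch.randn(batch_size, feat_dim)
labels = torch.tensor([2, 1, 0])
centers = nn.Parameter(torch.randn(num_classes, feat_dim))

# vectorized path, mirroring CenterLoss.forward
distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, num_classes) \
        + torch.pow(centers, 2).sum(dim=1, keepdim=True).expand(num_classes, batch_size).t()
distmat = distmat.addmm(x, centers.t(), beta=1, alpha=-2)
classes = torch.arange(num_classes).long()
mask = labels.unsqueeze(1).expand(batch_size, num_classes).eq(classes.expand(batch_size, num_classes))
loss = torch.cat([distmat[i][mask[i]].clamp(min=1e-12, max=1e+12)
                  for i in range(batch_size)]).mean()

# naive reference: mean over i of ||x_i - c_{y_i}||^2
ref = torch.stack([(x[i] - centers[labels[i]]).pow(2).sum()
                   for i in range(batch_size)]).mean()
print(torch.allclose(loss, ref, atol=1e-5))   # True
```

The two agree, which confirms that the expand/addmm/mask pipeline really computes the mean squared distance between each sample and its own class center.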