欢迎来到冰点文库! | 帮助中心 分享价值,成长自我!
冰点文库
全部分类
  • 临时分类>
  • IT计算机>
  • 经管营销>
  • 医药卫生>
  • 自然科学>
  • 农林牧渔>
  • 人文社科>
  • 工程科技>
  • PPT模板>
  • 求职职场>
  • 解决方案>
  • 总结汇报>
  • ImageVerifierCode 换一换
    首页 冰点文库 > 资源分类 > PDF文档下载
    分享到微信 分享到微博 分享到QQ空间

    LossCenterloss代码详解pytorch.pdf

    • 资源ID:14659184       资源大小:99.27KB        全文页数:3页
    • 资源格式: PDF        下载积分:8金币
    快捷下载 游客一键下载
    账号登录下载
    微信登录下载
    三方登录下载: 微信开放平台登录 QQ登录
    二维码
    微信扫一扫登录
    下载资源需要8金币
    邮箱/手机:
    温馨提示:
    快捷下载时,用户名和密码都是您填写的邮箱或者手机号,方便查询和重复下载(系统自动生成)。
    如填写123,账号就是123,密码也是123。
    支付方式: 支付宝    微信支付   
    验证码:   换一换

    加入VIP,免费下载
     
    账号:
    密码:
    验证码:   换一换
      忘记密码?
        
    友情提示
    2、PDF文件下载后,可能会被浏览器默认打开,此种情况可以点击浏览器菜单,保存网页到桌面,就可以正常下载了。
    3、本站不支持迅雷下载,请使用电脑自带的IE浏览器,或者360浏览器、谷歌浏览器下载即可。
    4、本站资源下载后的文档和图纸-无水印,预览文档经过压缩,下载后原文更清晰。
    5、试题试卷类文档,如果标题没有明确说明有答案则都视为没有答案,请知晓。

    LossCenterloss代码详解pytorch.pdf

    1、【Loss】Centerloss代码详解(pytorch)*注:*全部代码在最后,此代码不知来哪位神。m:batch size n:class size d:feat dimtorch.pow(x,2).sum(dim=1,keepdim=True).expand(batch_size,self.num_classes)同理self.centers=nn.Parameter(torch.randn(self.num_classes,self.feat_dim).cuda()torch.pow(self.centers,2).sum(dim=1,keepdim=True).expand(self

    2、.num_classes,batch_size).t()x=x0 x1.xm1x00 x10.x(m1)0 x01x11.x(m1)1.x0(d1)x1(d1).x(m1)(d1)Rmd=x02x12.xm12x+x+.+x0020120(d1)2x+x+.+x1021121(d1)2.x+x+.+x(m1)02(m1)12(m1)(d1)2Rm1=x02x12.xm12x02x12.xm12.x02x12.xm12Rmncenters=c00c10.c(n1)0c01c11.c(n1)1.c0(d1)c1(d1).c(n1)(d1)Rnd=12c+c+.+c0020120(d1)2c+c+.

    3、+c1021121(d1)2.c+c+.+c(n1)02(n1)12(n1)(d1)2Rn1=12Rnm=12Rmndistmat=torch.pow(x,2).sum(dim=1,keepdim=True).expand(batch_size,self.num_classes)+torch.pow(self.centers,2).sum(dim=1,keepdim=True).expand(self.num_classes,batch_size).t()distmat.addmm_(1,-2,x,self.centers.t()假设,可以看作样本与类中之间的距离,再结合后的mask。clas

    4、ses=torch.arange(self.num_classes).long()classes=0,1,n-1输的 对应每个样本的类别,labels=labels.unsqueeze(1).expand(batch_size,self.num_classes)unsqueeze(1):expand(batch_size,self.num_classes):每元素相同mask=labels.eq(classes.expand(batch_size,self.num_classes)将labels转换成one-hot?example:batch size=3 num class=4下这部分理解的

    5、是通过mask 找到中点,并且不断减样本与其对应类别中之间的距离。dismat=+12Rmndismat=distmat 2 x centersT=+122x00 x10.x(m1)0 x01x11.x(m1)1.x0(d1)x1(d1).x(m1)(d1)c00c10.c(n1)0c01c11.c(n1)1.c0(d1)c1(d1).c(n1)(d1)TRmndismat=d,0 iji m 1,0 j n 1xicjlabels Rmlabels Rm1labels Rmnlabels=210210210210classes.expand=000111222333mask=FalseFal

    6、seTrueFalseTrueFalseTrueFalseFalseFalseFalseFalsedist=for i in range(batch_size):value=distmatimaski value=value.clamp(min=1e-12,max=1e+12)#for numerical stability dist.append(value)dist=torch.cat(dist)loss=dist.mean()return loss全部代码class CenterLoss(nn.Module):Center loss.Reference:Wen et al.A Discr

    7、iminative Feature Learning Approach for Deep Face Recognition.ECCV 2016.Args:num_classes(int):number of classes.feat_dim(int):feature dimension.def _init_(self,num_classes=751,feat_dim=2048,use_gpu=True):super(CenterLoss,self)._init_()self.num_classes=num_classes self.feat_dim=feat_dim self.use_gpu=

    8、use_gpu if self.use_gpu:self.centers=nn.Parameter(torch.randn(self.num_classes,self.feat_dim).cuda()else:self.centers=nn.Parameter(torch.randn(self.num_classes,self.feat_dim)def forward(self,x,labels):Args:x:feature matrix with shape(batch_size,feat_dim).labels:ground truth labels with shape(num_cla

    9、sses).应该是batch_size assert x.size(0)=labels.size(0),features.size(0)is not equal to labels.size(0)batch_size=x.size(0)#distmat=torch.pow(x,2).sum(dim=1,keepdim=True).expand(batch_size,self.num_classes)+torch.pow(self.centers,2).sum(dim=1,keepdim=True).expand(self.num_classes,batch_size).t()#1*distma

    10、t-2*x*centers.t()distmat.addmm_(1,-2,x,self.centers.t()classes=torch.arange(self.num_classes).long()if self.use_gpu:classes=classes.cuda()labels=labels.unsqueeze(1).expand(batch_size,self.num_classes)mask=labels.eq(classes.expand(batch_size,self.num_classes)dist=for i in range(batch_size):value=distmatimaski value=value.clamp(min=1e-12,max=1e+12)#for numerical stability dist.append(value)dist=torch.cat(dist)loss=dist.mean()return loss


    注意事项

    本文(LossCenterloss代码详解pytorch.pdf)为本站会员主动上传,冰点文库仅提供信息存储空间,仅对用户上传内容的表现方式做保护处理,对上载内容本身不做任何修改或编辑。 若此文所含内容侵犯了您的版权或隐私,请立即通知冰点文库(点击联系客服),我们立即给予删除!

    温馨提示:如果因为网速或其他原因下载失败请重新下载,重复下载不扣分。




    关于我们 - 网站声明 - 网站地图 - 资源地图 - 友情链接 - 网站客服 - 联系我们

    copyright@ 2008-2023 冰点文库 网站版权所有

    经营许可证编号:鄂ICP备19020893号-2


    收起
    展开