YOLOv10模型改进-Backbone改进-第60篇: YOLOv10改进策略【Backbone】| PVT Backbone替换 一、本文介绍本文记录的是利用PVTPyramid Vision Transformer作为Backbone改进YOLOv10的特征提取部分。PVT通过金字塔结构和空间缩减注意力实现高效的多尺度特征提取。二、PVT模块介绍2.1 设计出发点ViT缺乏多尺度特征提取能力PVT通过金字塔结构和空间缩减注意力同时兼顾全局建模和多尺度特征。2.2 模块结构PVT块空间缩减注意力减少注意力计算复杂度前馈网络非线性变换层次化设计多尺度特征输出三、PVT的实现代码importtorchimporttorch.nnasnnclassSpatialReductionAttention(nn.Module):def__init__(self,dim,num_heads4,sr_ratio1):super().__init__()self.num_headsnum_heads self.scale(dim//num_heads)**-0.5self.qnn.Linear(dim,dim)self.kvnn.Linear(dim,dim*2)self.projnn.Linear(dim,dim)self.sr_ratiosr_ratioifsr_ratio1:self.srnn.Conv2d(dim,dim,sr_ratio,sr_ratio)self.normnn.LayerNorm(dim)defforward(self,x,H,W):B,N,Cx.shape qself.q(x).reshape(B,N,self.num_heads,C//self.num_heads).permute(0,2,1,3)ifself.sr_ratio1:x_x.transpose(1,2).view(B,C,H,W)x_self.sr(x_).reshape(B,C,-1).transpose(1,2)x_self.norm(x_)kvself.kv(x_).reshape(B,-1,2,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)else:kvself.kv(x).reshape(B,N,2,self.num_heads,C//self.num_heads).permute(2,0,3,1,4)k,vkv[0],kv[1]attn(q k.transpose(-2,-1))*self.scale attnattn.softmax(dim-1)x(attn v).transpose(1,2).reshape(B,N,C)returnself.proj(x)classPVTBasicLayer(nn.Module):def__init__(self,dim,num_heads,sr_ratio1):super().__init__()self.norm1nn.LayerNorm(dim)self.attnSpatialReductionAttention(dim,num_heads,sr_ratio)self.norm2nn.LayerNorm(dim)self.mlpnn.Sequential(nn.Linear(dim,dim*4),nn.GELU(),nn.Linear(dim*4,dim))defforward(self,x,H,W):xxself.attn(self.norm1(x),H,W)xxself.mlp(self.norm2(x))returnxclassPVT(nn.Module):def__init__(self,c13,c21024,embed_dims[64,128,256,512],num_heads[1,2,4,8],sr_ratios[8,4,2,1]):super().__init__()self.patch_embedsnn.ModuleList()self.patch_embeds.append(nn.Sequential(nn.Conv2d(c1,embed_dims[0],7,4,3),nn.LayerNorm(embed_dims[0])))foriinrange(1,4):self.patch_embeds.append(nn.Sequential(nn.Conv2d(embed_dims[i-1],embed_dims[i],3,2,1),nn.LayerNorm(embed_dims[i])))self.layersnn.ModuleList()foriinrange(4):self.layers.append(PVTBasicLayer(embed_dims[i],num_heads[i],sr_ratios[i]))self.final_convnn.Conv2d(embed_dims[-1],c2,1,biasFalse)defforward(self,x):Bx.shape[0]fori,embedinenumerate(self.patch_embeds):xembed(x)H,Wx.shape[2:]xx.flatten(2).transpose(1,2)xself.layers[i](x,H,W)ifi3:xx.transpose(1,2).reshape(B,-1,H,W)xx.transpose(1,2).reshape(B,-1,H,W)xself.final_conv(x)returnx四、创新模块将PVT作为Backbone集成到YOLOv10中# yolov10n_pvt.yamlbackbone:-[-1,1,PVT,[3,1024]]-[-1,1,SPPF,[1024,5]]五、预期结果模型mAP0.5mAP0.5:0.95参数量YOLOv10n52.3%27.9%2.7MYOLOv10n-PVT53.2%28.8%13.0M项目环境配置Python3.8.10PyTorch2.0.0CUDA11.8Ultralytics8.3.13