PolaLinearAtt如何平衡性能与效率?

酥酥 发布于 2025-10-20 111 次阅读


创新点

  • 提出了 极性感知线性注意力 (Polarity-aware Linear Attention),显式建模 query-key 的正负交互,解决了传统线性注意力丢失负值信息的问题

  • 引入 可学习的极性混合系数矩阵 (Gs, Go),替代简单的减法操作,稳定地结合同号与异号交互

  • 理论上证明了通过 正的一阶与二阶导数的函数 可以降低注意力分布的熵,从而恢复 softmax 的“尖锐性”

  • 实现了一个 逐通道可学习幂函数 rescaling,既保证了非负性,又恢复了 softmax 类似的稀疏性

  • 在 ImageNet、COCO、ADE20K、LRA 等多任务上验证了性能与效率的兼顾,提升最高达 +4.6%

方法

整体结构
      PolaFormer 在 Transformer 中用 极性感知线性注意力 取代标准注意力:它将 Query-Key 分解为正负部分,显式建模同号与异号的交互关系,并通过可学习极性混合矩阵进行融合;同时,利用逐通道的幂函数缩放恢复 softmax 的尖锐分布特性,并结合卷积提升秩表示,从而在保持线性复杂度的同时,实现更接近原始 softmax 的表达能力。
  • Query-Key 分解:将 Q、K 拆分为正/负部分,得到四类交互(同号:正-正、负-负;异号:正-负、负-正)

  • 双流处理:Value 向量在通道维度一分为二,分别对应同号交互流与异号交互流

  • 极性系数矩阵 Gs/Go:对两个流的输出加权融合,学习两类交互的互补性

  • 幂函数缩放:对 Q/K 映射引入逐通道可学习幂次 rescaling,降低熵,使注意力分布更尖锐

  • 卷积增强:可选引入 DWC/DCN 升秩,避免低秩退化

即插即用模块作用

PolaLinearAtt 适合在高效 Transformer 中即插即用,尤其在高分辨率视觉与长序列任务中,能以线性复杂度恢复更接近 softmax 的判别力与稀疏性

适用场景:
  • 高分辨率视觉任务:如目标检测、语义分割、实例分割,需要在大图像上高效建模全局依赖

  • 长序列任务:如长视频建模、长文本处理(LRA 等),线性复杂度能有效降低显存和计算压力

  • 资源受限环境:移动端或边缘设备部署,既要节省算力又要保持较好性能

  • 对细粒度关系敏感的任务:需要区分强/弱相关性或正/负相关性的场景(如小目标检测、细粒度分类)

模块作用:
  • 补全负值交互:传统线性注意力丢失负值信息,PolaLinearAtt 显式建模正负交互,使注意力更接近原始 softmax

  • 增强注意力稀疏性:通过幂函数缩放降低熵,恢复“尖锐”的注意力分布,提升模型判别力

  • 效率与性能兼顾:保持线性复杂度,显著提升分类、检测、分割任务的精度

  • 即插即用:可直接替换 Transformer 的自注意力模块,不改变整体结构

即插即用模块

				
					import torch
import torch.nnasnn

class PolaLinearAttention(nn.Module):
    def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1,
                 kernel_size=5, alpha=4):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.head_dim = head_dim

        self.qg = nn.Linear(dim, 2 * dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size,
                             groups=head_dim, padding=kernel_size // 2)
        
        self.power = nn.Parameter(torch.zeros(size=(1, self.num_heads, 1, self.head_dim)))
        self.alpha = alpha

        self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))
        self.positional_encoding = nn.Parameter(torch.zeros(size=(1, num_patches // (sr_ratio * sr_ratio), dim)))
        print('Linear Attention sr_ratio{} f{} kernel{}'.
              format(sr_ratio, alpha, kernel_size))

    def forward(self, x, H, W):
        B, N, C = x.shape
        q, g = self.qg(x).reshape(B, N, 2, C).unbind(2)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, C).permute(2, 0, 1, 3)
        else:
            kv = self.kv(x).reshape(B, -1, 2, C).permute(2, 0, 1, 3)
        k, v = kv[0], kv[1]
        n = k.shape[1]

        k = k + self.positional_encoding
        kernel_function = nn.ReLU()
        
        scale = nn.Softplus()(self.scale)
        power = 1 + self.alpha * torch.sigmoid(self.power)
        
        q = q / scale
        k = k / scale
        q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3).contiguous()
        k = k.reshape(B, n, self.num_heads, -1).permute(0, 2, 1, 3).contiguous()
        v = v.reshape(B, n, self.num_heads, -1).permute(0, 2, 1, 3).contiguous() 
        
        q_pos = kernel_function(q) ** power 
        q_neg = kernel_function(-q) ** power 
        k_pos = kernel_function(k) ** power 
        k_neg = kernel_function(-k) ** power 

        q_sim = torch.cat([q_pos, q_neg],dim=-1)
        q_opp = torch.cat([q_neg, q_pos],dim=-1)
        k = torch.cat([k_pos, k_neg],dim=-1)

        v1,v2 = torch.chunk(v,2,dim=-1)
        
        z = 1 / (q_sim @ k.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-6)
        kv = (k.transpose(-2, -1) * (n ** -0.5)) @ (v1 * (n ** -0.5))
        x_sim = q_sim @ kv * z
        z = 1 / (q_opp @ k.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-6)
        kv = (k.transpose(-2, -1) * (n ** -0.5)) @ (v2 * (n ** -0.5))
        x_opp = q_opp @ kv * z

        x = torch.cat([x_sim, x_opp],dim=-1)
        x = x.transpose(1, 2).reshape(B, N, C)

        if self.sr_ratio > 1:
            v = nn.functional.interpolate(v.transpose(-2, -1).reshape(B * self.num_heads, -1, n), size=N, mode='linear').reshape(B, self.num_heads, -1, N).transpose(-2, -1)
        
        v = v.reshape(B * self.num_heads, H, W, -1).permute(0, 3, 1, 2)
        v = self.dwc(v).reshape(B, C, N).permute(0, 2, 1)
        x = x + v
        x = x * g

        x = self.proj(x)
        x = self.proj_drop(x)

        returnx
    

if __name__ == "__main__":
    # 将模块移动到 GPU(如果可用)
    device = torch.device("cuda"if torch.cuda.is_available() else"cpu")
    # 创建测试输入张量 (batch_size, H * W, channels) / B N C
    x = torch.randn(1, 64, 128).to(device)
    # 初始化 pla 模块
    pla = PolaLinearAttention(dim=128, num_patches=64, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1,
                 kernel_size=5, alpha=4)
    print(pla)
    pla = pla.to(device)
    # 前向传播
    output = pla(x, H=8, W=8)
    
    # 打印输入和输出张量的形状
    print("输入张量形状:", x.shape)
    print("输出张量形状:", output.shape)