Vision Transformer中的Attention热力图可视化原理是什么?

酥酥 发布于 7 天前 20 次阅读


Vision Transformer中的Attention热力图可视化原理是什么?

一、ViT的注意力可视化在做一件什么事情?

实际上ViT的attention map可视化本质是提取 [CLS] token在最后一个Transformer block中对所有image patch的权重整合。

这并不是像Grad-CAM那样通过梯度回传计算出来的saliency map,在机制上两者之间是有明显区别的,ViT的注意力可视化和attention的计算过程密切相关。

在ViT的架构中,[CLS] token负责汇聚所有的image patches信息,用于最终分类并输出class label对应的logits,因此,它在最后一层的attention weights直接反映了:模型为了做classification关注哪些image patches更多。

而可视化这一过程的逻辑在于:我们只需要拿到这个权重向量,把它从一维序列还原回二维的grid结构,再插值上采样放大覆盖在原图上,就能得到直观的heatmap——这也是我们最为常见的ViT注意力的可视化形式。

二、代码层面上,注意力可视化过程是如何体现的?

下面是一段基于Hugging Face transformers库里的标准化可视化代码,展示了从加载模型、预处理图片、提取attention weights到最终可视化的完整流程。

				
					import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import ViTImageProcessor, ViTModel
 
# 1. Prepare model and image
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTModel.from_pretrained(model_name)
image = Image.open("test.jpg") # Replace with your image path
 
# 2. Inference and get attention
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    # Key point: set output_attentions=True to retrieve intermediate states
    outputs = model(**inputs, output_attentions=True)
 
# 3. Extract Attention weights
last_layer_attentions = outputs.attentions[-1][0] 
 
# 4. Process weights
attentions_mean = torch.mean(last_layer_attentions, dim=0)
patch_attentions = attentions_mean[0, 1:] 
 
# 5. Reshape back to 2D spatial grid
grid_size = int(np.sqrt(patch_attentions.shape[0]))
attentions_grid = patch_attentions.reshape(grid_size, grid_size)
 
# 6. Visualization
attentions_grid = torch.nn.functional.interpolate(
    attentions_grid.unsqueeze(0).unsqueeze(0), 
    size=image.size[::-1], 
    mode="bilinear"
).squeeze().numpy()
 
plt.imshow(image)
plt.imshow(attentions_grid, cmap='jet', alpha=0.5)
plt.axis('off')
plt.show()
				
			

–文章来源Alonze