扩散模型Classifer Guidance中classifier的梯度是如何传给U-Net的?

酥酥 发布于 2026-05-21 27 次阅读


扩散模型Classifer Guidance中classifier的梯度是如何传给U-Net的?

这是一个常见的误区:Classifier Guidance中的梯度没有传给U-Net,只是对U-Net预测的score function进行了修改,这个过程是完全不涉及模型训练的那种梯度更新的,而是为了计算一个与score function形式匹配的「梯度项」来修改预测噪声。

你可以把Classifier Guidance做的事情理解成:在U-Net预测出下一步的去噪方向后,强行加了一个指向目标类别的「向量场」。

一、数学视角上,Classifier Guidance的梯度项是怎么得到的?有什么物理意义?

二、Classifier Guidance中的梯度 vs. 模型训练中的梯度?

三、推理中通常需要关闭梯度保存,代码上如何实现?

我们都知道,推理时开启torch.no_grad()把整个pipeline包起来以节省显存是一个基本操作,但guidance机制要求我们必须对当前的中间状态求导,也就引发了一个思考——代码层面上的Classifier Guidance是怎么实现的呢?

具体的思路是在no_grad的内部嵌套一个enable_grad,并配合detach操作构建一个临时的、只包含当前step计算图的微型梯度图。就像下面这样:

				
					
def cond_fn(x, t, y, classifier, scale):
    with torch.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(logits)), y.view(-1)]
        return torch.autograd.grad(selected.sum(), x_in)[0] * scale
				
			
				
					# 假设在外层循环中:with torch.no_grad():
# model_output = unet(x, t)
# out = distribution_calculation(model_output) 
# out包含 'mean', 'variance' 等键值
 
if cond_fn is not None:
    gradient = cond_fn(x, t, y, classifier, scale)
    new_mean = out["mean"].float() + out["variance"] * gradient.float()
    out["mean"] = new_mean
				
			

–文章来源-alonzo