扩散模型Classifer Guidance中classifier的梯度是如何传给U-Net的?
这是一个常见的误区:Classifier Guidance中的梯度没有传给U-Net,只是对U-Net预测的score function进行了修改,这个过程是完全不涉及模型训练的那种梯度更新的,而是为了计算一个与score function形式匹配的「梯度项」来修改预测噪声。
你可以把Classifier Guidance做的事情理解成:在U-Net预测出下一步的去噪方向后,强行加了一个指向目标类别的「向量场」。
一、数学视角上,Classifier Guidance的梯度项是怎么得到的?有什么物理意义?
二、Classifier Guidance中的梯度 vs. 模型训练中的梯度?
三、推理中通常需要关闭梯度保存,代码上如何实现?
我们都知道,推理时开启torch.no_grad()把整个pipeline包起来以节省显存是一个基本操作,但guidance机制要求我们必须对当前的中间状态求导,也就引发了一个思考——代码层面上的Classifier Guidance是怎么实现的呢?
具体的思路是在no_grad的内部嵌套一个enable_grad,并配合detach操作构建一个临时的、只包含当前step计算图的微型梯度图。就像下面这样:
def cond_fn(x, t, y, classifier, scale):
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
logits = classifier(x_in, t)
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
selected = log_probs[range(len(logits)), y.view(-1)]
return torch.autograd.grad(selected.sum(), x_in)[0] * scale
# 假设在外层循环中:with torch.no_grad():
# model_output = unet(x, t)
# out = distribution_calculation(model_output)
# out包含 'mean', 'variance' 等键值
if cond_fn is not None:
gradient = cond_fn(x, t, y, classifier, scale)
new_mean = out["mean"].float() + out["variance"] * gradient.float()
out["mean"] = new_mean
–文章来源-alonzo
Comments NOTHING