Diffusion fine-tuning error: RuntimeError: One of the differentiated Tensors does not require grad
While recently working on fine-tuning a Diffusion Model, I repeatedly hit the following error during backpropagation after selecting which layers to fine-tune:
RuntimeError: One of the differentiated Tensors does not require grad
There is little about this error online, so I am recording it here.

Code: OpenAI-UNetModel
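The fine-tuning setup that triggers the problem freezes most of the network with requires_grad=False and leaves only a few layers trainable. A minimal sketch of that kind of setup (the helper name and the layer keyword are illustrative, not the exact code):

```python
import torch.nn as nn

def freeze_except(model: nn.Module, trainable_keywords):
    # Freeze every parameter, then re-enable grad only for parameters whose
    # name contains one of the given keywords (the layers being fine-tuned).
    for name, p in model.named_parameters():
        p.requires_grad_(any(k in name for k in trainable_keywords))

# e.g. only fine-tune the UNet decoder blocks (keyword is illustrative):
# freeze_except(unet, ["output_blocks"])
```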
Locating the bug:
- Enabled fine-tuning layer by layer: the later layers fine-tune without issue, and the error only appears when certain layers are made trainable;
- Looked at the code of the layers that raise the error and found that their forward pass goes through a function called `checkpoint()`:
```python
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
        explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)
```
```python
class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        # The error is raised here: every tensor in ctx.input_tensors + ctx.input_params
        # must have requires_grad=True, but frozen parameters do not.
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
```
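The torch.autograd.grad call treats ctx.input_tensors + ctx.input_params as the tensors to differentiate with respect to, and every one of them must have requires_grad=True (allow_unused only tolerates tensors that are unused in the graph, not tensors that do not require grad). A minimal standalone repro of that behavior, independent of the UNet code:

```python
import torch

x = torch.randn(3, requires_grad=True)
frozen = torch.randn(3, requires_grad=False)  # stands in for a frozen parameter

y = (x * 2).sum()

# Raises: RuntimeError: One of the differentiated Tensors does not require grad,
# because `frozen` is passed as a differentiation target without requires_grad.
torch.autograd.grad(y, [x, frozen], allow_unused=True)
```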
The cause: the fine-tuning setup sets requires_grad=False on the frozen layers, so ctx.input_params passed to torch.autograd.grad still contains parameters that do not require grad, and the call raises the error above. Fix: set flag=False, so that checkpoint() skips CheckpointFunction and simply calls func(*inputs); autograd then ignores the frozen parameters as usual.
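In OpenAI's UNet code the flag argument is driven by the blocks' use_checkpoint setting, so in practice this means building the model with use_checkpoint=False, or switching the flag off afterwards. A small sketch under that assumption (the helper name is mine; adjust the attribute name if your copy of the code differs). Note that this disables gradient checkpointing, so activation memory during training goes up:

```python
import torch.nn as nn

def disable_grad_checkpointing(model: nn.Module) -> None:
    # With use_checkpoint=False, checkpoint(func, inputs, params, flag) takes the
    # flag=False branch and simply calls func(*inputs), so frozen parameters never
    # reach torch.autograd.grad in CheckpointFunction.backward.
    for m in model.modules():
        if hasattr(m, "use_checkpoint"):
            m.use_checkpoint = False

# disable_grad_checkpointing(unet)  # call after building / freezing the model
```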