A small puzzle in the ControlNet source code
frc99 中本葱

Puzzle 1: where the control residuals are added

ControlNet
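In both the paper and the webui implementation, the ControlNet outputs are added at the inputs of the up blocks and at the mid block. In the diffusers implementation, however, once all the down blocks have run, down_block_additional_residuals is added to the down-block outputs, that is, to the skip tensors collected in down_block_res_samples. The effect is the same, because those tensors are only consumed later as the skip-connection inputs of the up blocks, but the different placement confused me when I first read the code.
The following is from https://github.com/huggingface/diffusers/blob/v0.27.2/src/diffusers/pipelines/controlnet/pipeline_controlnet.py#L1252: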

    down_block_res_samples, mid_block_res_sample = self.controlnet(
        control_model_input,
        t,
        encoder_hidden_states=controlnet_prompt_embeds,
        controlnet_cond=image,
        conditioning_scale=cond_scale,
        guess_mode=guess_mode,
        return_dict=False,
    )
    # predict the noise residual
    noise_pred = self.unet(
        latent_model_input,
        t,
        encoder_hidden_states=prompt_embeds,
        timestep_cond=timestep_cond,
        cross_attention_kwargs=self.cross_attention_kwargs,
        down_block_additional_residuals=down_block_res_samples,
        mid_block_additional_residual=mid_block_res_sample,
        added_cond_kwargs=added_cond_kwargs,
        return_dict=False,
    )[0]
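
The down_block_res_samples tuple returned by the ControlNet has to line up one-to-one with the skip tensors the UNet collects in its down path. As a rough sketch of how many entries that is, the helper below (count_skip_tensors is a hypothetical name, not a diffusers function) counts them under the assumption of a Stable Diffusion 1.5 style layout: 4 down blocks, 2 layers per block, and a downsampler on every block except the last.

```python
def count_skip_tensors(num_down_blocks: int, layers_per_block: int) -> int:
    """Count the skip tensors a diffusers-style UNet down path collects.

    Assumption: every down block except the final one ends with a downsampler.
    """
    count = 1  # the sample stored before any down block runs
    for i in range(num_down_blocks):
        count += layers_per_block  # one skip tensor per layer in the block
        is_final_block = i == num_down_blocks - 1
        if not is_final_block:
            count += 1  # one more skip tensor after the downsampler
    return count


# Assumed SD 1.5 layout: 4 down blocks, 2 layers per block.
sd15_skips = count_skip_tensors(num_down_blocks=4, layers_per_block=2)
print(sd15_skips)
```

Under these assumptions the ControlNet must supply that many down residuals plus a single mid residual.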

The following is from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py#L1208:

    # down blocks
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            # For t2i-adapter CrossAttnDownBlock2D
            additional_residuals = {}
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
                **additional_residuals,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                sample += down_intrablock_additional_residuals.pop(0)

        down_block_res_samples += res_samples

    # after all the down blocks have run
    if is_controlnet:
        new_down_block_res_samples = ()

        for down_block_res_sample, down_block_additional_residual in zip(
            down_block_res_samples, down_block_additional_residuals
        ):
            # the ControlNet residual is added to the stored skip tensor here
            down_block_res_sample = down_block_res_sample + down_block_additional_residual
            new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = new_down_block_res_samples
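
To convince myself that the two placements really are equivalent, here is a minimal toy sketch. It is not diffusers code: the "tensors" are plain floats, and run_down, run_up, paper_style and diffusers_style are hypothetical stand-ins. The point it tries to show is that the skip tensors are only read by the up path, so adding the residuals to the stored list right after the down path gives the same result as adding them at the moment each up block reads its input.

```python
def run_down(x, num_blocks=3):
    """Toy down path: returns the final sample and the collected skip values."""
    skips = [x]
    for i in range(num_blocks):
        x = 0.5 * x + i  # stand-in for a down block
        skips.append(x)
    return x, skips


def run_up(x, skips, extra=None):
    """Toy up path: each block pops one skip; `extra` is added as it is read."""
    skips = list(skips)
    extra = list(extra) if extra is not None else [0.0] * len(skips)
    while skips:
        x = 0.5 * x + (skips.pop() + extra.pop())  # stand-in for an up block
    return x


def paper_style(x, down_res, mid_res):
    """Add each residual where the up block consumes its skip input."""
    x, skips = run_down(x)
    x = (x + 1.0) + mid_res  # toy mid block, then the mid residual
    return run_up(x, skips, extra=down_res)


def diffusers_style(x, down_res, mid_res):
    """Add all residuals to the stored skips right after the down path."""
    x, skips = run_down(x)
    skips = [s + r for s, r in zip(skips, down_res)]
    x = (x + 1.0) + mid_res  # toy mid block, then the mid residual
    return run_up(x, skips)


down_res = [0.25, 0.5, 0.75, 1.0]  # one residual per skip value
a = paper_style(1.0, down_res, mid_res=2.0)
b = diffusers_style(1.0, down_res, mid_res=2.0)
print(a == b)
```

The toy blocks are arbitrary; the only structural assumption that matters is that nothing between the end of the down path and the up blocks reads the skip list.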