5353}
5454
5555
56+ def catenate_conds (conds ):
57+ if not isinstance (conds [0 ], dict ):
58+ return torch .cat (conds )
59+
60+ return {key : torch .cat ([x [key ] for x in conds ]) for key in conds [0 ].keys ()}
61+
62+
63+ def subscript_cond (cond , a , b ):
64+ if not isinstance (cond , dict ):
65+ return cond [a :b ]
66+
67+ return {key : vec [a :b ] for key , vec in cond .items ()}
68+
69+
70+ def pad_cond (tensor , repeats , empty ):
71+ if not isinstance (tensor , dict ):
72+ return torch .cat ([tensor , empty .repeat ((tensor .shape [0 ], repeats , 1 ))], axis = 1 )
73+
74+ tensor ['crossattn' ] = pad_cond (tensor ['crossattn' ], repeats , empty )
75+ return tensor
76+
77+
5678class CFGDenoiser (torch .nn .Module ):
5779 """
5880 Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
@@ -105,10 +127,13 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
105127
106128 if shared .sd_model .model .conditioning_key == "crossattn-adm" :
107129 image_uncond = torch .zeros_like (image_cond )
108- make_condition_dict = lambda c_crossattn , c_adm : {"c_crossattn" : c_crossattn , "c_adm" : c_adm }
130+ make_condition_dict = lambda c_crossattn , c_adm : {"c_crossattn" : [ c_crossattn ] , "c_adm" : c_adm }
109131 else :
110132 image_uncond = image_cond
111- make_condition_dict = lambda c_crossattn , c_concat : {"c_crossattn" : c_crossattn , "c_concat" : [c_concat ]}
133+ if isinstance (uncond , dict ):
134+ make_condition_dict = lambda c_crossattn , c_concat : {** c_crossattn , "c_concat" : [c_concat ]}
135+ else :
136+ make_condition_dict = lambda c_crossattn , c_concat : {"c_crossattn" : [c_crossattn ], "c_concat" : [c_concat ]}
112137
113138 if not is_edit_model :
114139 x_in = torch .cat ([torch .stack ([x [i ] for _ in range (n )]) for i , n in enumerate (repeats )] + [x ])
@@ -140,28 +165,28 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
140165 num_repeats = (tensor .shape [1 ] - uncond .shape [1 ]) // empty .shape [1 ]
141166
142167 if num_repeats < 0 :
143- tensor = torch . cat ([ tensor , empty . repeat (( tensor . shape [ 0 ], - num_repeats , 1 ))], axis = 1 )
168+ tensor = pad_cond ( tensor , - num_repeats , empty )
144169 self .padded_cond_uncond = True
145170 elif num_repeats > 0 :
146- uncond = torch . cat ([ uncond , empty . repeat (( uncond . shape [ 0 ], num_repeats , 1 ))], axis = 1 )
171+ uncond = pad_cond ( uncond , num_repeats , empty )
147172 self .padded_cond_uncond = True
148173
149174 if tensor .shape [1 ] == uncond .shape [1 ] or skip_uncond :
150175 if is_edit_model :
151- cond_in = torch . cat ([tensor , uncond , uncond ])
176+ cond_in = catenate_conds ([tensor , uncond , uncond ])
152177 elif skip_uncond :
153178 cond_in = tensor
154179 else :
155- cond_in = torch . cat ([tensor , uncond ])
180+ cond_in = catenate_conds ([tensor , uncond ])
156181
157182 if shared .batch_cond_uncond :
158- x_out = self .inner_model (x_in , sigma_in , cond = make_condition_dict ([ cond_in ] , image_cond_in ))
183+ x_out = self .inner_model (x_in , sigma_in , cond = make_condition_dict (cond_in , image_cond_in ))
159184 else :
160185 x_out = torch .zeros_like (x_in )
161186 for batch_offset in range (0 , x_out .shape [0 ], batch_size ):
162187 a = batch_offset
163188 b = a + batch_size
164- x_out [a :b ] = self .inner_model (x_in [a :b ], sigma_in [a :b ], cond = make_condition_dict ([ cond_in [a :b ] ], image_cond_in [a :b ]))
189+ x_out [a :b ] = self .inner_model (x_in [a :b ], sigma_in [a :b ], cond = make_condition_dict (cond_in [a :b ], image_cond_in [a :b ]))
165190 else :
166191 x_out = torch .zeros_like (x_in )
167192 batch_size = batch_size * 2 if shared .batch_cond_uncond else batch_size
@@ -170,14 +195,14 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
170195 b = min (a + batch_size , tensor .shape [0 ])
171196
172197 if not is_edit_model :
173- c_crossattn = [ tensor [ a : b ]]
198+ c_crossattn = subscript_cond ( tensor , a , b )
174199 else :
175200 c_crossattn = torch .cat ([tensor [a :b ]], uncond )
176201
177202 x_out [a :b ] = self .inner_model (x_in [a :b ], sigma_in [a :b ], cond = make_condition_dict (c_crossattn , image_cond_in [a :b ]))
178203
179204 if not skip_uncond :
180- x_out [- uncond .shape [0 ]:] = self .inner_model (x_in [- uncond .shape [0 ]:], sigma_in [- uncond .shape [0 ]:], cond = make_condition_dict ([ uncond ] , image_cond_in [- uncond .shape [0 ]:]))
205+ x_out [- uncond .shape [0 ]:] = self .inner_model (x_in [- uncond .shape [0 ]:], sigma_in [- uncond .shape [0 ]:], cond = make_condition_dict (uncond , image_cond_in [- uncond .shape [0 ]:]))
181206
182207 denoised_image_indexes = [x [0 ][0 ] for x in conds_list ]
183208 if skip_uncond :
0 commit comments