@@ -175,14 +175,14 @@ def get_multicond_prompt_list(prompts):
175175
176176class ComposableScheduledPromptConditioning :
177177 def __init__ (self , schedules , weight = 1.0 ):
178- self .schedules : list [ScheduledPromptConditioning ] = schedules
178+ self .schedules = schedules # : list[ScheduledPromptConditioning]
179179 self .weight : float = weight
180180
181181
182182class MulticondLearnedConditioning :
183183 def __init__ (self , shape , batch ):
184184 self .shape : tuple = shape # the shape field is needed to send this object to DDIM/PLMS
185- self .batch : list [list [ComposableScheduledPromptConditioning ]] = batch
185+ self .batch = batch # : list[list[ComposableScheduledPromptConditioning]]
186186
187187
188188def get_multicond_learned_conditioning (model , prompts , steps ) -> MulticondLearnedConditioning :
@@ -203,7 +203,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
203203 return MulticondLearnedConditioning (shape = (len (prompts ),), batch = res )
204204
205205
206- def reconstruct_cond_batch (c : list [list [ScheduledPromptConditioning ]], current_step ):
206+ def reconstruct_cond_batch (c , current_step ): # c: list[list[ScheduledPromptConditioning]]
207207 param = c [0 ][0 ].cond
208208 res = torch .zeros ((len (c ),) + param .shape , device = param .device , dtype = param .dtype )
209209 for i , cond_schedule in enumerate (c ):
0 commit comments