@@ -97,10 +97,26 @@ def get_schedule(prompt):
9797
9898
9999ScheduledPromptConditioning = namedtuple ("ScheduledPromptConditioning" , ["end_at_step" , "cond" ])
100- ScheduledPromptBatch = namedtuple ("ScheduledPromptBatch" , ["shape" , "schedules" ])
101100
102101
103102def get_learned_conditioning (model , prompts , steps ):
103+ """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
104+ and the sampling step at which this condition is to be replaced by the next one.
105+
106+ Input:
107+ (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
108+
109+ Output:
110+ [
111+ [
112+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
113+ ],
114+ [
115+ ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
116+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
117+ ]
118+ ]
119+ """
104120 res = []
105121
106122 prompt_schedules = get_learned_conditioning_prompt_schedules (prompts , steps )
@@ -123,13 +139,75 @@ def get_learned_conditioning(model, prompts, steps):
123139 cache [prompt ] = cond_schedule
124140 res .append (cond_schedule )
125141
126- return ScheduledPromptBatch ((len (prompts ),) + res [0 ][0 ].cond .shape , res )
142+ return res
143+
144+
145+ re_AND = re .compile (r"\bAND\b" )
146+ re_weight = re .compile (r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$" )
147+
148+
149+ def get_multicond_prompt_list (prompts ):
150+ res_indexes = []
151+
152+ prompt_flat_list = []
153+ prompt_indexes = {}
154+
155+ for prompt in prompts :
156+ subprompts = re_AND .split (prompt )
157+
158+ indexes = []
159+ for subprompt in subprompts :
160+ text , weight = re_weight .search (subprompt ).groups ()
161+
162+ weight = float (weight ) if weight is not None else 1.0
163+
164+ index = prompt_indexes .get (text , None )
165+ if index is None :
166+ index = len (prompt_flat_list )
167+ prompt_flat_list .append (text )
168+ prompt_indexes [text ] = index
169+
170+ indexes .append ((index , weight ))
171+
172+ res_indexes .append (indexes )
173+
174+ return res_indexes , prompt_flat_list , prompt_indexes
175+
176+
177+ class ComposableScheduledPromptConditioning :
178+ def __init__ (self , schedules , weight = 1.0 ):
179+ self .schedules : list [ScheduledPromptConditioning ] = schedules
180+ self .weight : float = weight
181+
182+
183+ class MulticondLearnedConditioning :
184+ def __init__ (self , shape , batch ):
185+ self .shape : tuple = shape # the shape field is needed to send this object to DDIM/PLMS
186+ self .batch : list [list [ComposableScheduledPromptConditioning ]] = batch
127187
128188
129- def reconstruct_cond_batch (c : ScheduledPromptBatch , current_step ):
130- param = c .schedules [0 ][0 ].cond
131- res = torch .zeros (c .shape , device = param .device , dtype = param .dtype )
132- for i , cond_schedule in enumerate (c .schedules ):
189+ def get_multicond_learned_conditioning (model , prompts , steps ) -> MulticondLearnedConditioning :
190+ """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
191+ For each prompt, the list is obtained by splitting the prompt using the AND separator.
192+
193+ https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
194+ """
195+
196+ res_indexes , prompt_flat_list , prompt_indexes = get_multicond_prompt_list (prompts )
197+
198+ learned_conditioning = get_learned_conditioning (model , prompt_flat_list , steps )
199+
200+ res = []
201+ for indexes in res_indexes :
202+ res .append ([ComposableScheduledPromptConditioning (learned_conditioning [i ], weight ) for i , weight in indexes ])
203+
204+ return MulticondLearnedConditioning (shape = (len (prompts ),), batch = res )
205+
206+
207+ def reconstruct_cond_batch (c : list [list [ScheduledPromptConditioning ]], current_step ):
208+ param = c [0 ][0 ].cond
209+ res = torch .zeros ((len (c ),) + param .shape , device = param .device , dtype = param .dtype )
210+ for i , cond_schedule in enumerate (c ):
133211 target_index = 0
134212 for current , (end_at , cond ) in enumerate (cond_schedule ):
135213 if current_step <= end_at :
@@ -140,6 +218,30 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
140218 return res
141219
142220
221+ def reconstruct_multicond_batch (c : MulticondLearnedConditioning , current_step ):
222+ param = c .batch [0 ][0 ].schedules [0 ].cond
223+
224+ tensors = []
225+ conds_list = []
226+
227+ for batch_no , composable_prompts in enumerate (c .batch ):
228+ conds_for_batch = []
229+
230+ for cond_index , composable_prompt in enumerate (composable_prompts ):
231+ target_index = 0
232+ for current , (end_at , cond ) in enumerate (composable_prompt .schedules ):
233+ if current_step <= end_at :
234+ target_index = current
235+ break
236+
237+ conds_for_batch .append ((len (tensors ), composable_prompt .weight ))
238+ tensors .append (composable_prompt .schedules [target_index ].cond )
239+
240+ conds_list .append (conds_for_batch )
241+
242+ return conds_list , torch .stack (tensors ).to (device = param .device , dtype = param .dtype )
243+
244+
143245re_attention = re .compile (r"""
144246\\\(|
145247\\\)|
0 commit comments