11import re
22from collections import namedtuple
3- import torch
4- from lark import Lark , Transformer , Visitor
5- import functools
63
7- import modules . shared as shared
4+ import lark
85
96# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
107# will be represented with prompt_schedule like this (assuming steps=100):
1411# [75, 'fantasy landscape with a lake and an oak in background masterful']
1512# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
1613
14+ schedule_parser = lark .Lark (r"""
15+ !start: (prompt | /[][():]/+)*
16+ prompt: (emphasized | scheduled | plain | WHITESPACE)*
17+ !emphasized: "(" prompt ")"
18+ | "(" prompt ":" prompt ")"
19+ | "[" prompt "]"
20+ scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
21+ WHITESPACE: /\s+/
22+ plain: /([^\\\[\]():]|\\.)+/
23+ %import common.SIGNED_NUMBER -> NUMBER
24+ """ )
1725
1826def get_learned_conditioning_prompt_schedules (prompts , steps ):
19- grammar = r"""
20- start: prompt
21- prompt: (emphasized | scheduled | weighted | plain)*
22- !emphasized: "(" prompt ")"
23- | "(" prompt ":" prompt ")"
24- | "[" prompt "]"
25- scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
26- !weighted: "{" weighted_item ("|" weighted_item)* "}"
27- !weighted_item: prompt (":" prompt)?
28- plain: /([^\\\[\](){}:|]|\\.)+/
29- %import common.SIGNED_NUMBER -> NUMBER
3027 """
31- parser = Lark (grammar , parser = 'lalr' )
28+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
29+ >>> g("test")
30+ [[10, 'test']]
31+ >>> g("a [b:3]")
32+ [[3, 'a '], [10, 'a b']]
33+ >>> g("a [b: 3]")
34+ [[3, 'a '], [10, 'a b']]
35+ >>> g("a [[[b]]:2]")
36+ [[2, 'a '], [10, 'a [[b]]']]
37+ >>> g("[(a:2):3]")
38+ [[3, ''], [10, '(a:2)']]
39+ >>> g("a [b : c : 1] d")
40+ [[1, 'a b d'], [10, 'a c d']]
41+ >>> g("a[b:[c:d:2]:1]e")
42+ [[1, 'abe'], [2, 'ace'], [10, 'ade']]
43+ >>> g("a [unbalanced")
44+ [[10, 'a [unbalanced']]
45+ >>> g("a [b:.5] c")
46+ [[5, 'a c'], [10, 'a b c']]
47+ >>> g("a [{b|d{:.5] c") # not handling this right now
48+ [[5, 'a c'], [10, 'a {b|d{ c']]
49+ >>> g("((a][:b:c [d:3]")
50+ [[3, '((a][:b:c '], [10, '((a][:b:c d']]
51+ """
3252
3353 def collect_steps (steps , tree ):
3454 l = [steps ]
35- class CollectSteps (Visitor ):
55+ class CollectSteps (lark . Visitor ):
3656 def scheduled (self , tree ):
3757 tree .children [- 1 ] = float (tree .children [- 1 ])
3858 if tree .children [- 1 ] < 1 :
@@ -43,30 +63,33 @@ def scheduled(self, tree):
4363 return sorted (set (l ))
4464
4565 def at_step (step , tree ):
46- class AtStep (Transformer ):
66+ class AtStep (lark . Transformer ):
4767 def scheduled (self , args ):
48- if len (args ) == 2 :
49- before , after , when = (), * args
50- else :
51- before , after , when = args
52- yield before if step <= when else after
68+ before , after , _ , when = args
69+ yield before or () if step <= when else after
5370 def start (self , args ):
5471 def flatten (x ):
5572 if type (x ) == str :
5673 yield x
5774 else :
5875 for gen in x :
5976 yield from flatten (gen )
60- return '' .join (flatten (args [ 0 ] ))
77+ return '' .join (flatten (args ))
6178 def plain (self , args ):
6279 yield args [0 ].value
6380 def __default__ (self , data , children , meta ):
6481 for child in children :
6582 yield from child
6683 return AtStep ().transform (tree )
67-
84+
6885 def get_schedule (prompt ):
69- tree = parser .parse (prompt )
86+ try :
87+ tree = schedule_parser .parse (prompt )
88+ except lark .exceptions .LarkError as e :
89+ if 0 :
90+ import traceback
91+ traceback .print_exc ()
92+ return [[steps , prompt ]]
7093 return [[t , at_step (t , tree )] for t in collect_steps (steps , tree )]
7194
7295 promptdict = {prompt : get_schedule (prompt ) for prompt in set (prompts )}
@@ -77,8 +100,7 @@ def get_schedule(prompt):
77100ScheduledPromptBatch = namedtuple ("ScheduledPromptBatch" , ["shape" , "schedules" ])
78101
79102
80- def get_learned_conditioning (prompts , steps ):
81-
103+ def get_learned_conditioning (model , prompts , steps ):
82104 res = []
83105
84106 prompt_schedules = get_learned_conditioning_prompt_schedules (prompts , steps )
@@ -92,7 +114,7 @@ def get_learned_conditioning(prompts, steps):
92114 continue
93115
94116 texts = [x [1 ] for x in prompt_schedule ]
95- conds = shared . sd_model .get_learned_conditioning (texts )
117+ conds = model .get_learned_conditioning (texts )
96118
97119 cond_schedule = []
98120 for i , (end_at_step , text ) in enumerate (prompt_schedule ):
@@ -105,12 +127,13 @@ def get_learned_conditioning(prompts, steps):
105127
106128
107129def reconstruct_cond_batch (c : ScheduledPromptBatch , current_step ):
108- res = torch .zeros (c .shape , device = shared .device , dtype = next (shared .sd_model .parameters ()).dtype )
130+ param = c .schedules [0 ][0 ].cond
131+ res = torch .zeros (c .shape , device = param .device , dtype = param .dtype )
109132 for i , cond_schedule in enumerate (c .schedules ):
110133 target_index = 0
111- for curret_index , (end_at , cond ) in enumerate (cond_schedule ):
134+ for current , (end_at , cond ) in enumerate (cond_schedule ):
112135 if current_step <= end_at :
113- target_index = curret_index
136+ target_index = current
114137 break
115138 res [i ] = cond_schedule [target_index ].cond
116139
@@ -148,23 +171,26 @@ def parse_prompt_attention(text):
148171 \\ - literal character '\'
149172 anything else - just text
150173
151- Example:
152-
153- 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
154-
155- produces:
156-
157- [
158- ['a ', 1.0],
159- ['house', 1.5730000000000004],
160- [' ', 1.1],
161- ['on', 1.0],
162- [' a ', 1.1],
163- ['hill', 0.55],
164- [', sun, ', 1.1],
165- ['sky', 1.4641000000000006],
166- ['.', 1.1]
167- ]
174+ >>> parse_prompt_attention('normal text')
175+ [['normal text', 1.0]]
176+ >>> parse_prompt_attention('an (important) word')
177+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
178+ >>> parse_prompt_attention('(unbalanced')
179+ [['unbalanced', 1.1]]
180+ >>> parse_prompt_attention('\(literal\]')
181+ [['(literal]', 1.0]]
182+ >>> parse_prompt_attention('(unnecessary)(parens)')
183+ [['unnecessaryparens', 1.1]]
184+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
185+ [['a ', 1.0],
186+ ['house', 1.5730000000000004],
187+ [' ', 1.1],
188+ ['on', 1.0],
189+ [' a ', 1.1],
190+ ['hill', 0.55],
191+ [', sun, ', 1.1],
192+ ['sky', 1.4641000000000006],
193+ ['.', 1.1]]
168194 """
169195
170196 res = []
@@ -206,4 +232,19 @@ def multiply_range(start_position, multiplier):
206232 if len (res ) == 0 :
207233 res = [["" , 1.0 ]]
208234
235+ # merge runs of identical weights
236+ i = 0
237+ while i + 1 < len (res ):
238+ if res [i ][1 ] == res [i + 1 ][1 ]:
239+ res [i ][0 ] += res [i + 1 ][0 ]
240+ res .pop (i + 1 )
241+ else :
242+ i += 1
243+
209244 return res
245+
246+ if __name__ == "__main__" :
247+ import doctest
248+ doctest .testmod (optionflags = doctest .NORMALIZE_WHITESPACE )
249+ else :
250+ import torch # doctest faster
0 commit comments