Skip to content

Commit 90e911f

Browse files
raefu authored and AUTOMATIC1111 committed
prompt_parser: allow spaces in schedules, add test, log/ignore errors
Only build the parser once (at import time) instead of for each step. doctest is run by simply executing modules/prompt_parser.py
1 parent 1eb588c commit 90e911f

2 files changed

Lines changed: 95 additions & 54 deletions

File tree

‎modules/processing.py‎

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom
8484
self.s_tmin = opts.s_tmin
8585
self.s_tmax = float('inf') # not representable as a standard ui option
8686
self.s_noise = opts.s_noise
87-
87+
8888
if not seed_enable_extras:
8989
self.subseed = -1
9090
self.subseed_strength = 0
@@ -296,7 +296,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
296296
assert(len(p.prompt) > 0)
297297
else:
298298
assert p.prompt is not None
299-
299+
300300
devices.torch_gc()
301301

302302
seed = get_fixed_seed(p.seed)
@@ -359,8 +359,8 @@ def infotext(iteration=0, position_in_batch=0):
359359
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
360360
#c = p.sd_model.get_learned_conditioning(prompts)
361361
with devices.autocast():
362-
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
363-
c = prompt_parser.get_learned_conditioning(prompts, p.steps)
362+
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
363+
c = prompt_parser.get_learned_conditioning(shared.sd_model, prompts, p.steps)
364364

365365
if len(model_hijack.comments) > 0:
366366
for comment in model_hijack.comments:
@@ -527,7 +527,7 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
527527
# GC now before running the next img2img to prevent running out of memory
528528
x = None
529529
devices.torch_gc()
530-
530+
531531
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
532532

533533
return samples

‎modules/prompt_parser.py‎

Lines changed: 90 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import re
22
from collections import namedtuple
3-
import torch
4-
from lark import Lark, Transformer, Visitor
5-
import functools
63

7-
import modules.shared as shared
4+
import lark
85

96
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
107
# will be represented with prompt_schedule like this (assuming steps=100):
@@ -14,25 +11,48 @@
1411
# [75, 'fantasy landscape with a lake and an oak in background masterful']
1512
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
1613

14+
schedule_parser = lark.Lark(r"""
15+
!start: (prompt | /[][():]/+)*
16+
prompt: (emphasized | scheduled | plain | WHITESPACE)*
17+
!emphasized: "(" prompt ")"
18+
| "(" prompt ":" prompt ")"
19+
| "[" prompt "]"
20+
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
21+
WHITESPACE: /\s+/
22+
plain: /([^\\\[\]():]|\\.)+/
23+
%import common.SIGNED_NUMBER -> NUMBER
24+
""")
1725

1826
def get_learned_conditioning_prompt_schedules(prompts, steps):
19-
grammar = r"""
20-
start: prompt
21-
prompt: (emphasized | scheduled | weighted | plain)*
22-
!emphasized: "(" prompt ")"
23-
| "(" prompt ":" prompt ")"
24-
| "[" prompt "]"
25-
scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
26-
!weighted: "{" weighted_item ("|" weighted_item)* "}"
27-
!weighted_item: prompt (":" prompt)?
28-
plain: /([^\\\[\](){}:|]|\\.)+/
29-
%import common.SIGNED_NUMBER -> NUMBER
3027
"""
31-
parser = Lark(grammar, parser='lalr')
28+
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
29+
>>> g("test")
30+
[[10, 'test']]
31+
>>> g("a [b:3]")
32+
[[3, 'a '], [10, 'a b']]
33+
>>> g("a [b: 3]")
34+
[[3, 'a '], [10, 'a b']]
35+
>>> g("a [[[b]]:2]")
36+
[[2, 'a '], [10, 'a [[b]]']]
37+
>>> g("[(a:2):3]")
38+
[[3, ''], [10, '(a:2)']]
39+
>>> g("a [b : c : 1] d")
40+
[[1, 'a b d'], [10, 'a c d']]
41+
>>> g("a[b:[c:d:2]:1]e")
42+
[[1, 'abe'], [2, 'ace'], [10, 'ade']]
43+
>>> g("a [unbalanced")
44+
[[10, 'a [unbalanced']]
45+
>>> g("a [b:.5] c")
46+
[[5, 'a c'], [10, 'a b c']]
47+
>>> g("a [{b|d{:.5] c") # not handling this right now
48+
[[5, 'a c'], [10, 'a {b|d{ c']]
49+
>>> g("((a][:b:c [d:3]")
50+
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
51+
"""
3252

3353
def collect_steps(steps, tree):
3454
l = [steps]
35-
class CollectSteps(Visitor):
55+
class CollectSteps(lark.Visitor):
3656
def scheduled(self, tree):
3757
tree.children[-1] = float(tree.children[-1])
3858
if tree.children[-1] < 1:
@@ -43,30 +63,33 @@ def scheduled(self, tree):
4363
return sorted(set(l))
4464

4565
def at_step(step, tree):
46-
class AtStep(Transformer):
66+
class AtStep(lark.Transformer):
4767
def scheduled(self, args):
48-
if len(args) == 2:
49-
before, after, when = (), *args
50-
else:
51-
before, after, when = args
52-
yield before if step <= when else after
68+
before, after, _, when = args
69+
yield before or () if step <= when else after
5370
def start(self, args):
5471
def flatten(x):
5572
if type(x) == str:
5673
yield x
5774
else:
5875
for gen in x:
5976
yield from flatten(gen)
60-
return ''.join(flatten(args[0]))
77+
return ''.join(flatten(args))
6178
def plain(self, args):
6279
yield args[0].value
6380
def __default__(self, data, children, meta):
6481
for child in children:
6582
yield from child
6683
return AtStep().transform(tree)
67-
84+
6885
def get_schedule(prompt):
69-
tree = parser.parse(prompt)
86+
try:
87+
tree = schedule_parser.parse(prompt)
88+
except lark.exceptions.LarkError as e:
89+
if 0:
90+
import traceback
91+
traceback.print_exc()
92+
return [[steps, prompt]]
7093
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
7194

7295
promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
@@ -77,8 +100,7 @@ def get_schedule(prompt):
77100
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
78101

79102

80-
def get_learned_conditioning(prompts, steps):
81-
103+
def get_learned_conditioning(model, prompts, steps):
82104
res = []
83105

84106
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
@@ -92,7 +114,7 @@ def get_learned_conditioning(prompts, steps):
92114
continue
93115

94116
texts = [x[1] for x in prompt_schedule]
95-
conds = shared.sd_model.get_learned_conditioning(texts)
117+
conds = model.get_learned_conditioning(texts)
96118

97119
cond_schedule = []
98120
for i, (end_at_step, text) in enumerate(prompt_schedule):
@@ -105,12 +127,13 @@ def get_learned_conditioning(prompts, steps):
105127

106128

107129
def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
108-
res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype)
130+
param = c.schedules[0][0].cond
131+
res = torch.zeros(c.shape, device=param.device, dtype=param.dtype)
109132
for i, cond_schedule in enumerate(c.schedules):
110133
target_index = 0
111-
for curret_index, (end_at, cond) in enumerate(cond_schedule):
134+
for current, (end_at, cond) in enumerate(cond_schedule):
112135
if current_step <= end_at:
113-
target_index = curret_index
136+
target_index = current
114137
break
115138
res[i] = cond_schedule[target_index].cond
116139

@@ -148,23 +171,26 @@ def parse_prompt_attention(text):
148171
\\ - literal character '\'
149172
anything else - just text
150173
151-
Example:
152-
153-
'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
154-
155-
produces:
156-
157-
[
158-
['a ', 1.0],
159-
['house', 1.5730000000000004],
160-
[' ', 1.1],
161-
['on', 1.0],
162-
[' a ', 1.1],
163-
['hill', 0.55],
164-
[', sun, ', 1.1],
165-
['sky', 1.4641000000000006],
166-
['.', 1.1]
167-
]
174+
>>> parse_prompt_attention('normal text')
175+
[['normal text', 1.0]]
176+
>>> parse_prompt_attention('an (important) word')
177+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
178+
>>> parse_prompt_attention('(unbalanced')
179+
[['unbalanced', 1.1]]
180+
>>> parse_prompt_attention('\(literal\]')
181+
[['(literal]', 1.0]]
182+
>>> parse_prompt_attention('(unnecessary)(parens)')
183+
[['unnecessaryparens', 1.1]]
184+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
185+
[['a ', 1.0],
186+
['house', 1.5730000000000004],
187+
[' ', 1.1],
188+
['on', 1.0],
189+
[' a ', 1.1],
190+
['hill', 0.55],
191+
[', sun, ', 1.1],
192+
['sky', 1.4641000000000006],
193+
['.', 1.1]]
168194
"""
169195

170196
res = []
@@ -206,4 +232,19 @@ def multiply_range(start_position, multiplier):
206232
if len(res) == 0:
207233
res = [["", 1.0]]
208234

235+
# merge runs of identical weights
236+
i = 0
237+
while i + 1 < len(res):
238+
if res[i][1] == res[i + 1][1]:
239+
res[i][0] += res[i + 1][0]
240+
res.pop(i + 1)
241+
else:
242+
i += 1
243+
209244
return res
245+
246+
if __name__ == "__main__":
247+
import doctest
248+
doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
249+
else:
250+
import torch # doctest faster

0 commit comments

Comments (0)