Skip to content

Commit 2f1b61d

Browse files
guaneecAUTOMATIC1111
authored andcommitted
Allow nested structures inside schedules
1 parent 6c6ae28 commit 2f1b61d

3 files changed

Lines changed: 55 additions & 66 deletions

File tree

‎modules/prompt_parser.py‎

Lines changed: 53 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,11 @@
11
import re
22
from collections import namedtuple
33
import torch
4+
from lark import Lark, Transformer, Visitor
5+
import functools
46

57
import modules.shared as shared
68

7-
re_prompt = re.compile(r'''
8-
(.*?)
9-
\[
10-
([^]:]+):
11-
(?:([^]:]*):)?
12-
([0-9]*\.?[0-9]+)
13-
]
14-
|
15-
(.+)
16-
''', re.X)
17-
189
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
1910
# will be represented with prompt_schedule like this (assuming steps=100):
2011
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
@@ -25,61 +16,57 @@
2516

2617

2718
def get_learned_conditioning_prompt_schedules(prompts, steps):
28-
res = []
29-
cache = {}
30-
31-
for prompt in prompts:
32-
prompt_schedule: list[list[str | int]] = [[steps, ""]]
33-
34-
cached = cache.get(prompt, None)
35-
if cached is not None:
36-
res.append(cached)
37-
continue
38-
39-
for m in re_prompt.finditer(prompt):
40-
plaintext = m.group(1) if m.group(5) is None else m.group(5)
41-
concept_from = m.group(2)
42-
concept_to = m.group(3)
43-
if concept_to is None:
44-
concept_to = concept_from
45-
concept_from = ""
46-
swap_position = float(m.group(4)) if m.group(4) is not None else None
47-
48-
if swap_position is not None:
49-
if swap_position < 1:
50-
swap_position = swap_position * steps
51-
swap_position = int(min(swap_position, steps))
52-
53-
swap_index = None
54-
found_exact_index = False
55-
for i in range(len(prompt_schedule)):
56-
end_step = prompt_schedule[i][0]
57-
prompt_schedule[i][1] += plaintext
58-
59-
if swap_position is not None and swap_index is None:
60-
if swap_position == end_step:
61-
swap_index = i
62-
found_exact_index = True
63-
64-
if swap_position < end_step:
65-
swap_index = i
66-
67-
if swap_index is not None:
68-
if not found_exact_index:
69-
prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
70-
71-
for i in range(len(prompt_schedule)):
72-
end_step = prompt_schedule[i][0]
73-
must_replace = swap_position < end_step
74-
75-
prompt_schedule[i][1] += concept_to if must_replace else concept_from
76-
77-
res.append(prompt_schedule)
78-
cache[prompt] = prompt_schedule
79-
#for t in prompt_schedule:
80-
# print(t)
81-
82-
return res
19+
grammar = r"""
20+
start: prompt
21+
prompt: (emphasized | scheduled | weighted | plain)*
22+
!emphasized: "(" prompt ")"
23+
| "(" prompt ":" prompt ")"
24+
| "[" prompt "]"
25+
scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
26+
!weighted: "{" weighted_item ("|" weighted_item)* "}"
27+
!weighted_item: prompt (":" prompt)?
28+
plain: /([^\\\[\](){}:|]|\\.)+/
29+
%import common.SIGNED_NUMBER -> NUMBER
30+
"""
31+
parser = Lark(grammar, parser='lalr')
32+
def collect_steps(steps, tree):
33+
l = [steps]
34+
class CollectSteps(Visitor):
35+
def scheduled(self, tree):
36+
tree.children[-1] = float(tree.children[-1])
37+
if tree.children[-1] < 1:
38+
tree.children[-1] *= steps
39+
tree.children[-1] = min(steps, int(tree.children[-1]))
40+
l.append(tree.children[-1])
41+
CollectSteps().visit(tree)
42+
return sorted(set(l))
43+
def at_step(step, tree):
44+
class AtStep(Transformer):
45+
def scheduled(self, args):
46+
if len(args) == 2:
47+
before, after, when = (), *args
48+
else:
49+
before, after, when = args
50+
yield before if step <= when else after
51+
def start(self, args):
52+
def flatten(x):
53+
if type(x) == str:
54+
yield x
55+
else:
56+
for gen in x:
57+
yield from flatten(gen)
58+
return ''.join(flatten(args[0]))
59+
def plain(self, args):
60+
yield args[0].value
61+
def __default__(self, data, children, meta):
62+
for child in children:
63+
yield from child
64+
return AtStep().transform(tree)
65+
@functools.cache
66+
def get_schedule(prompt):
67+
tree = parser.parse(prompt)
68+
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
69+
return [get_schedule(prompt) for prompt in prompts]
8370

8471

8572
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])

‎requirements.txt‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ clean-fid
2222
resize-right
2323
torchdiffeq
2424
kornia
25+
lark

‎requirements_versions.txt‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ clean-fid==0.1.29
2121
resize-right==0.0.2
2222
torchdiffeq==0.2.3
2323
kornia==0.6.7
24+
lark==1.1.2

0 commit comments

Comments
 (0)