|
1 | 1 | import re |
2 | 2 | from collections import namedtuple |
3 | 3 | import torch |
| 4 | +from lark import Lark, Transformer, Visitor |
| 5 | +import functools |
4 | 6 |
|
5 | 7 | import modules.shared as shared |
6 | 8 |
|
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']


def get_learned_conditioning_prompt_schedules(prompts, steps):
    """Expand scheduling syntax in each prompt into a per-step schedule.

    Supported construct: "[before:after:when]" switches the text from
    `before` to `after` once sampling passes step `when`; "[after:when]"
    means `before` is empty. A `when` below 1 is interpreted as a fraction
    of `steps`. Emphasis "(...)", "(...:w)", "[...]" and alternation
    "{a|b}" tokens are parsed but passed through as-is here.

    Returns a list with one entry per prompt; each entry is a list of
    [end_at_step, prompt_text] pairs, ordered by step, where prompt_text
    is the fully rendered prompt in effect up to and including end_at_step.

    A prompt the grammar cannot parse (e.g. unbalanced brackets) is not an
    error: it is returned whole as a single unscheduled entry
    [[steps, prompt]].
    """
    import lark  # module-level name for lark.exceptions.LarkError; the file already imports from lark

    grammar = r"""
    start: prompt
    prompt: (emphasized | scheduled | weighted | plain)*
    !emphasized: "(" prompt ")"
            | "(" prompt ":" prompt ")"
            | "[" prompt "]"
    scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
    !weighted: "{" weighted_item ("|" weighted_item)* "}"
    !weighted_item: prompt (":" prompt)?
    plain: /([^\\\[\](){}:|]|\\.)+/
    %import common.SIGNED_NUMBER -> NUMBER
    """
    parser = Lark(grammar, parser='lalr')

    def collect_steps(steps, tree):
        # Gather every step at which the rendered prompt changes. `steps`
        # itself is always included so the final interval is closed.
        # NOTE: this also normalizes each scheduled NUMBER token in-place
        # (fraction -> absolute step, clamped to `steps`); at_step below
        # relies on that normalization having happened first.
        step_marks = [steps]

        class CollectSteps(Visitor):
            def scheduled(self, tree):
                tree.children[-1] = float(tree.children[-1])
                if tree.children[-1] < 1:
                    # fractional position: scale to the total step count
                    tree.children[-1] *= steps
                tree.children[-1] = min(steps, int(tree.children[-1]))
                step_marks.append(tree.children[-1])

        CollectSteps().visit(tree)
        return sorted(set(step_marks))

    def at_step(step, tree):
        # Render the prompt text that is active at sampling step `step`.
        class AtStep(Transformer):
            def scheduled(self, args):
                if len(args) == 2:
                    # "[after:when]" form: no 'before' part
                    before, after, when = (), *args
                else:
                    before, after, when = args
                yield before if step <= when else after

            def start(self, args):
                def flatten(x):
                    if type(x) == str:
                        yield x
                    else:
                        for gen in x:
                            yield from flatten(gen)
                return ''.join(flatten(args[0]))

            def plain(self, args):
                yield args[0].value

            def __default__(self, data, children, meta):
                for child in children:
                    yield from child

        return AtStep().transform(tree)

    @functools.cache
    def get_schedule(prompt):
        try:
            tree = parser.parse(prompt)
        except lark.exceptions.LarkError:
            # Prompt doesn't use (valid) scheduling syntax — e.g. unbalanced
            # brackets. Treat the whole text as one unscheduled prompt
            # instead of crashing conditioning.
            return [[steps, prompt]]
        # collect_steps must run before at_step: it normalizes the NUMBERs.
        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]

    # get_schedule is cached, so duplicate prompts are parsed only once.
    return [get_schedule(prompt) for prompt in prompts]
83 | 70 |
|
84 | 71 |
|
# One entry of a prompt's conditioning schedule: `cond` (presumably the model
# conditioning for this prompt segment — confirm against callers) is in effect
# up to and including sampling step `end_at_step`.
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|
0 commit comments