Skip to content

Commit 8e2aeee

Browse files
committed
add BREAK keyword to end current text chunk and start the next
1 parent 205991d commit 8e2aeee

2 files changed

Lines changed: 19 additions & 5 deletions

File tree

‎modules/prompt_parser.py‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
274274
:
275275
""", re.X)
276276

277+
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
277278

278279
def parse_prompt_attention(text):
279280
"""
@@ -339,7 +340,11 @@ def multiply_range(start_position, multiplier):
339340
elif text == ']' and len(square_brackets) > 0:
340341
multiply_range(square_brackets.pop(), square_bracket_multiplier)
341342
else:
342-
res.append([text, 1.0])
343+
parts = re.split(re_break, text)
344+
for i, part in enumerate(parts):
345+
if i > 0:
346+
res.append(["BREAK", -1])
347+
res.append([part, 1.0])
343348

344349
for pos in round_brackets:
345350
multiply_range(pos, round_bracket_multiplier)

‎modules/sd_hijack_clip.py‎

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,18 @@ def tokenize_line(self, line):
9696
token_count = 0
9797
last_comma = -1
9898

99-
def next_chunk():
100-
"""puts current chunk into the list of results and produces the next one - empty"""
99+
def next_chunk(is_last=False):
100+
"""puts current chunk into the list of results and produces the next one - empty;
101+
if is_last is true, <end-of-text> tokens at the end won't be added to token_count"""
101102
nonlocal token_count
102103
nonlocal last_comma
103104
nonlocal chunk
104105

105-
token_count += len(chunk.tokens)
106+
if is_last:
107+
token_count += len(chunk.tokens)
108+
else:
109+
token_count += self.chunk_length
110+
106111
to_add = self.chunk_length - len(chunk.tokens)
107112
if to_add > 0:
108113
chunk.tokens += [self.id_end] * to_add
@@ -116,6 +121,10 @@ def next_chunk():
116121
chunk = PromptChunk()
117122

118123
for tokens, (text, weight) in zip(tokenized, parsed):
124+
if text == 'BREAK' and weight == -1:
125+
next_chunk()
126+
continue
127+
119128
position = 0
120129
while position < len(tokens):
121130
token = tokens[position]
@@ -159,7 +168,7 @@ def next_chunk():
159168
position += embedding_length_in_tokens
160169

161170
if len(chunk.tokens) > 0 or len(chunks) == 0:
162-
next_chunk()
171+
next_chunk(is_last=True)
163172

164173
return chunks, token_count
165174

0 commit comments

Comments
 (0)