Skip to content

Commit 0027ce1

Browse files
Merge pull request #12457 from rubberbaron/shared-hires-prompt-test
prompt editing timeline has separate range for first pass and hires-fix pass
2 parents 06f1818 + 99ab3d4 commit 0027ce1

3 files changed

Lines changed: 42 additions & 16 deletions

File tree

‎modules/processing.py‎

100755100644
Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -407,12 +407,14 @@ def setup_prompts(self):
407407
self.main_prompt = self.all_prompts[0]
408408
self.main_negative_prompt = self.all_negative_prompts[0]
409409

410-
def cached_params(self, required_prompts, steps, extra_network_data):
410+
def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):
411411
"""Returns parameters that invalidate the cond cache if changed"""
412412

413413
return (
414414
required_prompts,
415415
steps,
416+
hires_steps,
417+
use_old_scheduling,
416418
opts.CLIP_stop_at_last_layers,
417419
shared.sd_model.sd_checkpoint_info,
418420
extra_network_data,
@@ -422,7 +424,7 @@ def cached_params(self, required_prompts, steps, extra_network_data):
422424
self.height,
423425
)
424426

425-
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
427+
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
426428
"""
427429
Returns the result of calling function(shared.sd_model, required_prompts, steps)
428430
using a cache to store the result if the same arguments have been used before.
@@ -435,7 +437,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
435437
caches is a list with items described above.
436438
"""
437439

438-
cached_params = self.cached_params(required_prompts, steps, extra_network_data)
440+
cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)
439441

440442
for cache in caches:
441443
if cache[0] is not None and cached_params == cache[0]:
@@ -444,7 +446,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
444446
cache = caches[0]
445447

446448
with devices.autocast():
447-
cache[1] = function(shared.sd_model, required_prompts, steps)
449+
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
448450

449451
cache[0] = cached_params
450452
return cache[1]
@@ -456,6 +458,8 @@ def setup_conds(self):
456458
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
457459
total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
458460
self.step_multiplier = total_steps // self.steps
461+
self.firstpass_steps = total_steps
462+
459463
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
460464
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
461465

@@ -1292,8 +1296,8 @@ def calculate_hr_conds(self):
12921296
steps = self.hr_second_pass_steps or self.steps
12931297
total_steps = sampler_config.total_steps(steps) if sampler_config else steps
12941298

1295-
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, total_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
1296-
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, total_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
1299+
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
1300+
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
12971301

12981302
def setup_conds(self):
12991303
if self.is_hr_pass:

‎modules/prompt_parser.py‎

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
%import common.SIGNED_NUMBER -> NUMBER
2727
""")
2828

29-
def get_learned_conditioning_prompt_schedules(prompts, steps):
29+
def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
3030
"""
3131
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
3232
>>> g("test")
@@ -57,18 +57,39 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
5757
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
5858
>>> g("[fe|||]male")
5959
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
60+
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
61+
>>> g("a [b:.5] c")
62+
[[10, 'a b c']]
63+
>>> g("a [b:1.5] c")
64+
[[5, 'a c'], [10, 'a b c']]
6065
"""
6166

67+
if hires_steps is None or use_old_scheduling:
68+
int_offset = 0
69+
flt_offset = 0
70+
steps = base_steps
71+
else:
72+
int_offset = base_steps
73+
flt_offset = 1.0
74+
steps = hires_steps
75+
6276
def collect_steps(steps, tree):
6377
res = [steps]
6478

6579
class CollectSteps(lark.Visitor):
6680
def scheduled(self, tree):
67-
tree.children[-2] = float(tree.children[-2])
68-
if tree.children[-2] < 1:
69-
tree.children[-2] *= steps
70-
tree.children[-2] = min(steps, int(tree.children[-2]))
71-
res.append(tree.children[-2])
81+
s = tree.children[-2]
82+
v = float(s)
83+
if use_old_scheduling:
84+
v = v*steps if v<1 else v
85+
else:
86+
if "." in s:
87+
v = (v - flt_offset) * steps
88+
else:
89+
v = (v - int_offset)
90+
tree.children[-2] = min(steps, int(v))
91+
if tree.children[-2] >= 1:
92+
res.append(tree.children[-2])
7293

7394
def alternate(self, tree):
7495
res.extend(range(1, steps+1))
@@ -134,7 +155,7 @@ def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, c
134155

135156

136157

137-
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
158+
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
138159
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
139160
and the sampling step at which this condition is to be replaced by the next one.
140161
@@ -154,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
154175
"""
155176
res = []
156177

157-
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
178+
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
158179
cache = {}
159180

160181
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
@@ -229,7 +250,7 @@ def __init__(self, shape, batch):
229250
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
230251

231252

232-
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
253+
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
233254
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
234255
For each prompt, the list is obtained by splitting the prompt using the AND separator.
235256
@@ -238,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
238259

239260
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
240261

241-
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
262+
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
242263

243264
res = []
244265
for indexes in res_indexes:

‎modules/shared_options.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@
203203
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
204204
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
205205
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
206+
"use_old_scheduling": OptionInfo(False, "Use old prompt where first pass and hires both used the same timeline, and < 1 meant relative and >= 1 meant absolute"),
206207
}))
207208

208209
options_templates.update(options_section(('interrogate', "Interrogate"), {

0 commit comments

Comments
 (0)