Skip to content

Commit 5f24b7b

Browse files
committed
option to let users select which samplers they want to hide
1 parent 6e7057b commit 5f24b7b

4 files changed

Lines changed: 35 additions & 16 deletions

File tree

‎modules/processing.py‎

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111
from skimage import exposure
1212

1313
import modules.sd_hijack
14-
from modules import devices, prompt_parser, masking
14+
from modules import devices, prompt_parser, masking, sd_samplers
1515
from modules.sd_hijack import model_hijack
16-
from modules.sd_samplers import samplers, samplers_for_img2img
1716
from modules.shared import opts, cmd_opts, state
1817
import modules.shared as shared
1918
import modules.face_restoration
@@ -110,7 +109,7 @@ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="",
110109
self.width = p.width
111110
self.height = p.height
112111
self.sampler_index = p.sampler_index
113-
self.sampler = samplers[p.sampler_index].name
112+
self.sampler = sd_samplers.samplers[p.sampler_index].name
114113
self.cfg_scale = p.cfg_scale
115114
self.steps = p.steps
116115
self.batch_size = p.batch_size
@@ -265,7 +264,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
265264

266265
generation_params = {
267266
"Steps": p.steps,
268-
"Sampler": samplers[p.sampler_index].name,
267+
"Sampler": sd_samplers.samplers[p.sampler_index].name,
269268
"CFG scale": p.cfg_scale,
270269
"Seed": all_seeds[index],
271270
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
@@ -478,7 +477,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
478477
self.firstphase_height_truncated = int(scale * self.height)
479478

480479
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
481-
self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
480+
self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
482481

483482
if not self.enable_hr:
484483
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@@ -521,7 +520,7 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
521520

522521
shared.state.nextjob()
523522

524-
self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
523+
self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
525524
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
526525

527526
# GC now before running the next img2img to prevent running out of memory
@@ -556,7 +555,7 @@ def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mas
556555
self.nmask = None
557556

558557
def init(self, all_prompts, all_seeds, all_subseeds):
559-
self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
558+
self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
560559
crop_region = None
561560

562561
if self.image_mask is not None:

‎modules/sd_samplers.py‎

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,27 @@
3232
if hasattr(k_diffusion.sampling, funcname)
3333
]
3434

35-
samplers = [
35+
all_samplers = [
3636
*samplers_data_k_diffusion,
3737
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
3838
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
3939
]
40-
samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
40+
41+
samplers = []
42+
samplers_for_img2img = []
43+
44+
45+
def set_samplers():
46+
global samplers, samplers_for_img2img
47+
48+
hidden = set(opts.hide_samplers)
49+
hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
50+
51+
samplers = [x for x in all_samplers if x.name not in hidden]
52+
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
53+
54+
55+
set_samplers()
4156

4257
sampler_extra_params = {
4358
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],

‎modules/shared.py‎

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import modules.sd_models
1414
import modules.styles
1515
import modules.devices as devices
16+
from modules import sd_samplers
1617
from modules.paths import script_path, sd_path
1718

1819
sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -238,14 +239,16 @@ def options_section(section_identifer, options_dict):
238239
}))
239240

240241
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
241-
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
242-
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
243-
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
244-
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
245-
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
246-
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
242+
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
243+
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
244+
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
245+
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
246+
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
247+
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
248+
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
247249
}))
248250

251+
249252
class Options:
250253
data = None
251254
data_labels = options_templates

‎webui.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import threading
33
import time
44
import importlib
5-
from modules import devices
5+
from modules import devices, sd_samplers
66
from modules.paths import script_path
77
import signal
88
import threading
@@ -109,6 +109,8 @@ def sigint_handler(sig, frame):
109109
time.sleep(0.5)
110110
break
111111

112+
sd_samplers.set_samplers()
113+
112114
print('Reloading Custom Scripts')
113115
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
114116
print('Reloading modules: modules.ui')

0 commit comments

Comments
 (0)