Skip to content

Commit da464a3

Browse files
committed
SDXL support
1 parent af08121 commit da464a3

16 files changed

Lines changed: 242 additions & 45 deletions

‎modules/launch_utils.py‎

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,20 @@ def run_extensions_installers(settings_file):
224224
run_extension_installer(os.path.join(extensions_dir, dirname_extension))
225225

226226

227+
def mute_sdxl_imports():
228+
"""create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
229+
230+
import importlib
231+
232+
module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('taming.modules.losses.lpips', None))
233+
module.LPIPS = None
234+
sys.modules['taming.modules.losses.lpips'] = module
235+
236+
module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('sgm.data', None))
237+
module.StableDataModuleFromConfig = None
238+
sys.modules['sgm.data'] = module
239+
240+
227241
def prepare_environment():
228242
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
229243
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
@@ -319,11 +333,14 @@ def prepare_environment():
319333
if args.update_all_extensions:
320334
git_pull_recursive(extensions_dir)
321335

336+
mute_sdxl_imports()
337+
322338
if "--exit" in sys.argv:
323339
print("Exiting because of --exit argument")
324340
exit(0)
325341

326342

343+
327344
def configure_for_tests():
328345
if "--api" not in sys.argv:
329346
sys.argv.append("--api")

‎modules/lowvram.py‎

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,46 @@ def first_stage_model_decode_wrap(z):
5353
send_me_to_gpu(first_stage_model, None)
5454
return first_stage_model_decode(z)
5555

56-
# for SD1, cond_stage_model is CLIP and its NN is in the transformer field, but for SD2, it's open clip, and it's in model field
57-
if hasattr(sd_model.cond_stage_model, 'model'):
58-
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
59-
60-
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
61-
# send the model to GPU. Then put modules back. the modules will be in CPU.
62-
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
63-
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
56+
to_remain_in_cpu = [
57+
(sd_model, 'first_stage_model'),
58+
(sd_model, 'depth_model'),
59+
(sd_model, 'embedder'),
60+
(sd_model, 'model'),
61+
(sd_model, 'embedder'),
62+
]
63+
64+
is_sdxl = hasattr(sd_model, 'conditioner')
65+
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
66+
67+
if is_sdxl:
68+
to_remain_in_cpu.append((sd_model, 'conditioner'))
69+
elif is_sd2:
70+
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
71+
else:
72+
to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
73+
74+
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
75+
stored = []
76+
for obj, field in to_remain_in_cpu:
77+
module = getattr(obj, field, None)
78+
stored.append(module)
79+
setattr(obj, field, None)
80+
81+
# send the model to GPU.
6482
sd_model.to(devices.device)
65-
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
83+
84+
# put modules back. the modules will be in CPU.
85+
for (obj, field), module in zip(to_remain_in_cpu, stored):
86+
setattr(obj, field, module)
6687

6788
# register hooks for those the first three models
68-
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
89+
if is_sdxl:
90+
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
91+
elif is_sd2:
92+
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
93+
else:
94+
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
95+
6996
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
7097
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
7198
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
@@ -75,10 +102,6 @@ def first_stage_model_decode_wrap(z):
75102
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
76103
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
77104

78-
if hasattr(sd_model.cond_stage_model, 'model'):
79-
sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
80-
del sd_model.cond_stage_model.transformer
81-
82105
if use_medvram:
83106
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
84107
else:

‎modules/paths.py‎

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
path_dirs = [
2222
(sd_path, 'ldm', 'Stable Diffusion', []),
23-
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []),
23+
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
2424
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
2525
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
2626
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -36,6 +36,13 @@
3636
d = os.path.abspath(d)
3737
if "atstart" in options:
3838
sys.path.insert(0, d)
39+
elif "sgm" in options:
40+
# Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
41+
# import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir.
42+
43+
sys.path.insert(0, d)
44+
import sgm
45+
sys.path.pop(0)
3946
else:
4047
sys.path.append(d)
4148
paths[what] = d

‎modules/processing.py‎

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,13 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
343343
return cache[1]
344344

345345
def setup_conds(self):
346+
prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
347+
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height)
348+
346349
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
347350
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
348-
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
349-
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
351+
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
352+
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
350353

351354
def parse_extra_network_prompts(self):
352355
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)

‎modules/prompt_parser.py‎

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import re
24
from collections import namedtuple
35
from typing import List
@@ -109,7 +111,19 @@ def get_schedule(prompt):
109111
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
110112

111113

112-
def get_learned_conditioning(model, prompts, steps):
114+
class SdConditioning(list):
115+
"""
116+
A list with prompts for stable diffusion's conditioner model.
117+
Can also specify width and height of created image - SDXL needs it.
118+
"""
119+
def __init__(self, prompts, width=None, height=None):
120+
super().__init__()
121+
self.extend(prompts)
122+
self.width = width or getattr(prompts, 'width', None)
123+
self.height = height or getattr(prompts, 'height', None)
124+
125+
126+
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
113127
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
114128
and the sampling step at which this condition is to be replaced by the next one.
115129
@@ -160,11 +174,13 @@ def get_learned_conditioning(model, prompts, steps):
160174
re_AND = re.compile(r"\bAND\b")
161175
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
162176

163-
def get_multicond_prompt_list(prompts):
177+
178+
def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
164179
res_indexes = []
165180

166-
prompt_flat_list = []
167181
prompt_indexes = {}
182+
prompt_flat_list = SdConditioning(prompts)
183+
prompt_flat_list.clear()
168184

169185
for prompt in prompts:
170186
subprompts = re_AND.split(prompt)
@@ -201,6 +217,7 @@ def __init__(self, shape, batch):
201217
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
202218
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
203219

220+
204221
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
205222
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
206223
For each prompt, the list is obtained by splitting the prompt using the AND separator.

‎modules/sd_hijack.py‎

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
import ldm.models.diffusion.plms
1616
import ldm.modules.encoders.modules
1717

18+
import sgm.modules.attention
19+
import sgm.modules.diffusionmodules.model
20+
import sgm.modules.diffusionmodules.openaimodel
21+
import sgm.modules.encoders.modules
22+
1823
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
1924
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
2025
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
@@ -56,6 +61,9 @@ def apply_optimizations(option=None):
5661
ldm.modules.diffusionmodules.model.nonlinearity = silu
5762
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
5863

64+
sgm.modules.diffusionmodules.model.nonlinearity = silu
65+
sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
66+
5967
if current_optimizer is not None:
6068
current_optimizer.undo()
6169
current_optimizer = None
@@ -89,6 +97,10 @@ def undo_optimizations():
8997
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
9098
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
9199

100+
sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
101+
sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
102+
sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
103+
92104

93105
def fix_checkpoint():
94106
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
@@ -170,10 +182,19 @@ def hijack(self, m):
170182
if conditioner:
171183
for i in range(len(conditioner.embedders)):
172184
embedder = conditioner.embedders[i]
173-
if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder':
185+
typename = type(embedder).__name__
186+
if typename == 'FrozenOpenCLIPEmbedder':
174187
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
175188
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
176189
conditioner.embedders[i] = m.cond_stage_model
190+
if typename == 'FrozenCLIPEmbedder':
191+
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
192+
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
193+
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self)
194+
conditioner.embedders[i] = m.cond_stage_model
195+
if typename == 'FrozenOpenCLIPEmbedder2':
196+
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
197+
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
177198

178199
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
179200
model_embeddings = m.cond_stage_model.roberta.embeddings

‎modules/sd_hijack_clip.py‎

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ def __init__(self, wrapped, hijack):
4242
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
4343
self.chunk_length = 75
4444

45+
self.is_trainable = getattr(wrapped, 'is_trainable', False)
46+
self.input_key = getattr(wrapped, 'input_key', 'txt')
47+
self.legacy_ucg_val = None
48+
4549
def empty_chunk(self):
4650
"""creates an empty PromptChunk and returns it"""
4751

@@ -199,8 +203,9 @@ def forward(self, texts):
199203
"""
200204
Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
201205
Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
202-
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
206+
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
203207
An example shape returned by this function can be: (2, 77, 768).
208+
For SDXL, instead of returning one tensor above, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
204209
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
205210
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
206211
"""
@@ -233,7 +238,10 @@ def forward(self, texts):
233238
embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
234239
self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
235240

236-
return torch.hstack(zs)
241+
if getattr(self.wrapped, 'return_pooled', False):
242+
return torch.hstack(zs), zs[0].pooled
243+
else:
244+
return torch.hstack(zs)
237245

238246
def process_tokens(self, remade_batch_tokens, batch_multipliers):
239247
"""
@@ -256,9 +264,9 @@ def process_tokens(self, remade_batch_tokens, batch_multipliers):
256264
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
257265
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
258266
original_mean = z.mean()
259-
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
267+
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
260268
new_mean = z.mean()
261-
z = z * (original_mean / new_mean)
269+
z *= (original_mean / new_mean)
262270

263271
return z
264272

‎modules/sd_hijack_open_clip.py‎

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ def __init__(self, wrapped, hijack):
1616
self.id_end = tokenizer.encoder["<end_of_text>"]
1717
self.id_pad = 0
1818

19-
self.is_trainable = getattr(wrapped, 'is_trainable', False)
20-
self.input_key = getattr(wrapped, 'input_key', 'txt')
21-
self.legacy_ucg_val = None
22-
2319
def tokenize(self, texts):
2420
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
2521

@@ -39,3 +35,37 @@ def encode_embedding_init_text(self, init_text, nvpt):
3935
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
4036

4137
return embedded
38+
39+
40+
class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
41+
def __init__(self, wrapped, hijack):
42+
super().__init__(wrapped, hijack)
43+
44+
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
45+
self.id_start = tokenizer.encoder["<start_of_text>"]
46+
self.id_end = tokenizer.encoder["<end_of_text>"]
47+
self.id_pad = 0
48+
49+
def tokenize(self, texts):
50+
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
51+
52+
tokenized = [tokenizer.encode(text) for text in texts]
53+
54+
return tokenized
55+
56+
def encode_with_transformers(self, tokens):
57+
d = self.wrapped.encode_with_transformer(tokens)
58+
z = d[self.wrapped.layer]
59+
60+
pooled = d.get("pooled")
61+
if pooled is not None:
62+
z.pooled = pooled
63+
64+
return z
65+
66+
def encode_embedding_init_text(self, init_text, nvpt):
67+
ids = tokenizer.encode(init_text)
68+
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
69+
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
70+
71+
return embedded

0 commit comments

Comments
 (0)