Commit af08121

getting SD2.1 to run on SDXL repo
1 parent 7b83329 commit af08121

9 files changed: 152 additions & 24 deletions

modules/launch_utils.py

Lines changed: 3 additions & 0 deletions
@@ -235,11 +235,13 @@ def prepare_environment():
     openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
 
     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
+    stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
     k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
     codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
 
     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
+    stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
     codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

@@ -297,6 +299,7 @@ def prepare_environment():
     os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
 
     git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
+    git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
     git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
     git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
     git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
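
As with the existing repos, both new values can be overridden through the environment before launch. A minimal sketch (the override values below are hypothetical; the variable names come from the diff above):

    import os

    # Hypothetical override: point the SDXL repo at a fork and pin a different
    # commit before prepare_environment() reads these variables.
    os.environ['STABLE_DIFFUSION_XL_REPO'] = 'https://github.com/my-fork/generative-models.git'
    os.environ['STABLE_DIFFUSION_XL_COMMIT_HASH'] = '0123456789abcdef0123456789abcdef01234567'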

modules/paths.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 
 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),
+    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []),
     (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
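
Each path_dirs entry is (directory, marker that must exist inside it, display name, flags); the new entry registers the generative-models checkout using the sgm package as its marker. A self-contained sketch of what registering such an entry amounts to (the helper and checkout path are illustrative, not the actual paths.py code):

    import os
    import sys

    def register_repo(repo_dir, must_exist, atstart=False):
        # If the marker (here the 'sgm' package dir) is present, put the repo
        # on sys.path so `import sgm` resolves; entries flagged "atstart"
        # (like k-diffusion above) go to the front of the path instead.
        if os.path.exists(os.path.join(repo_dir, must_exist)):
            if atstart:
                sys.path.insert(0, repo_dir)
            else:
                sys.path.append(repo_dir)

    register_repo('repositories/generative-models', 'sgm')  # hypothetical checkout location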

modules/prompt_parser.py

Lines changed: 52 additions & 12 deletions
@@ -144,7 +144,12 @@ def get_learned_conditioning(model, prompts, steps):
 
         cond_schedule = []
         for i, (end_at_step, _) in enumerate(prompt_schedule):
-            cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+            if isinstance(conds, dict):
+                cond = {k: v[i] for k, v in conds.items()}
+            else:
+                cond = conds[i]
+
+            cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
 
         cache[prompt] = cond_schedule
         res.append(cond_schedule)

@@ -214,20 +219,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
     return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
 
 
+class DictWithShape(dict):
+    def __init__(self, x, shape):
+        super().__init__()
+        self.update(x)
+
+    @property
+    def shape(self):
+        return self["crossattn"].shape
+
+
 def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
     param = c[0][0].cond
-    res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+    is_dict = isinstance(param, dict)
+
+    if is_dict:
+        dict_cond = param
+        res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
+        res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
+    else:
+        res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+
     for i, cond_schedule in enumerate(c):
         target_index = 0
         for current, entry in enumerate(cond_schedule):
             if current_step <= entry.end_at_step:
                 target_index = current
                 break
-        res[i] = cond_schedule[target_index].cond
+
+        if is_dict:
+            for k, param in cond_schedule[target_index].cond.items():
+                res[k][i] = param
+        else:
+            res[i] = cond_schedule[target_index].cond
 
     return res
 
 
+def stack_conds(tensors):
+    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+    # and won't be able to torch.stack them. So this fixes that.
+    token_count = max([x.shape[0] for x in tensors])
+    for i in range(len(tensors)):
+        if tensors[i].shape[0] != token_count:
+            last_vector = tensors[i][-1:]
+            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+    return torch.stack(tensors)
+
+
 def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
     param = c.batch[0][0].schedules[0].cond
 
@@ -249,16 +291,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
 
         conds_list.append(conds_for_batch)
 
-    # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
-    # and won't be able to torch.stack them. So this fixes that.
-    token_count = max([x.shape[0] for x in tensors])
-    for i in range(len(tensors)):
-        if tensors[i].shape[0] != token_count:
-            last_vector = tensors[i][-1:]
-            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
-            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+    if isinstance(tensors[0], dict):
+        keys = list(tensors[0].keys())
+        stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
+        stacked = DictWithShape(stacked, stacked['crossattn'].shape)
+    else:
+        stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
 
-    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+    return conds_list, stacked
 
 
 re_attention = re.compile(r"""
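
SDXL conditionings arrive as a dict of tensors, one entry per kind of embedder output, rather than a single tensor, while much of the downstream code only inspects .shape. DictWithShape papers over that difference. A self-contained sketch of the behavior (the tensor sizes are made up):

    import torch

    class DictWithShape(dict):
        # A dict of cond tensors that still answers .shape like a plain tensor
        # by delegating to its "crossattn" entry (the shape argument is unused,
        # matching the implementation above).
        def __init__(self, x, shape):
            super().__init__()
            self.update(x)

        @property
        def shape(self):
            return self["crossattn"].shape

    cond = DictWithShape({
        "crossattn": torch.zeros(2, 77, 2048),  # per-token text embeddings
        "vector": torch.zeros(2, 2816),         # pooled/extra conditioning
    }, None)

    print(cond.shape)            # torch.Size([2, 77, 2048]) -- tensor-style access
    print(cond["vector"].shape)  # torch.Size([2, 2816])     -- dict-style access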

modules/sd_hijack.py

Lines changed: 9 additions & 0 deletions
@@ -166,6 +166,15 @@ def apply_optimizations(self, option=None):
         undo_optimizations()
 
     def hijack(self, m):
+        conditioner = getattr(m, 'conditioner', None)
+        if conditioner:
+            for i in range(len(conditioner.embedders)):
+                embedder = conditioner.embedders[i]
+                if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder':
+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+                    m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
+                    conditioner.embedders[i] = m.cond_stage_model
+
         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
             model_embeddings = m.cond_stage_model.roberta.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
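
Matching the embedder by class-name string instead of isinstance means sd_hijack does not have to import the class from either codebase (ldm and sgm each ship a FrozenOpenCLIPEmbedder). A self-contained sketch of the pattern, with stand-in classes:

    class FrozenOpenCLIPEmbedder:  # stand-in for the ldm/sgm class of the same name
        pass

    class OtherEmbedder:
        pass

    def find_open_clip_embedders(embedders):
        # Compare the class name as a string so no import from ldm or sgm is needed.
        return [e for e in embedders if type(e).__name__ == 'FrozenOpenCLIPEmbedder']

    embedders = [FrozenOpenCLIPEmbedder(), OtherEmbedder()]
    print(len(find_open_clip_embedders(embedders)))  # 1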

modules/sd_hijack_open_clip.py

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,10 @@ def __init__(self, wrapped, hijack):
         self.id_end = tokenizer.encoder["<end_of_text>"]
         self.id_pad = 0
 
+        self.is_trainable = getattr(wrapped, 'is_trainable', False)
+        self.input_key = getattr(wrapped, 'input_key', 'txt')
+        self.legacy_ucg_val = None
+
     def tokenize(self, texts):
         assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
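
These three attributes exist because the wrapper is swapped into conditioner.embedders by the sd_hijack change above, and sgm's conditioner reads them off every embedder it drives. A rough, simplified sketch of such a consuming loop (stand-in code, not the actual sgm implementation):

    class TextEmbedder:
        # Minimal stand-in exposing the attributes the conditioner loop reads;
        # the webui wrapper forwards them from the wrapped embedder.
        is_trainable = False
        input_key = 'txt'
        legacy_ucg_val = None

        def __call__(self, texts):
            return [f"emb({t})" for t in texts]

    def run_conditioner(embedders, batch):
        # Each embedder declares which batch key feeds it via input_key.
        return [embedder(batch[embedder.input_key]) for embedder in embedders]

    print(run_conditioner([TextEmbedder()], {'txt': ['a photo of a cat']}))
    # [['emb(a photo of a cat)']]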

modules/sd_models.py

Lines changed: 6 additions & 2 deletions
@@ -14,7 +14,7 @@
 
 from ldm.util import instantiate_from_config
 
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
 from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer
 import tomesd

@@ -289,6 +289,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     if state_dict is None:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
+    if hasattr(model, 'conditioner'):
+        sd_models_xl.extend_sdxl(model)
+
     model.load_state_dict(state_dict, strict=False)
     del state_dict
     timer.record("apply weights to model")

@@ -334,7 +337,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     model.sd_checkpoint_info = checkpoint_info
     shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
 
-    model.logvar = model.logvar.to(devices.device)  # fix for training
+    if hasattr(model, 'logvar'):
+        model.logvar = model.logvar.to(devices.device)  # fix for training
 
     sd_vae.delete_base_vae()
     sd_vae.clear_loaded_vae()
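
hasattr(model, 'conditioner') doubles as the SDXL detector here: ldm's LatentDiffusion exposes cond_stage_model but no conditioner, while sgm's DiffusionEngine carries one. A self-contained sketch of the dispatch, with stand-in classes:

    class LatentDiffusionStub:     # ldm-style SD1/SD2 model: no .conditioner
        cond_stage_model = object()

    class DiffusionEngineStub:     # sgm-style SDXL model: has a .conditioner
        conditioner = object()

    def needs_sdxl_extension(model):
        # Same attribute test the diff performs before sd_models_xl.extend_sdxl.
        return hasattr(model, 'conditioner')

    print(needs_sdxl_extension(LatentDiffusionStub()))  # False
    print(needs_sdxl_extension(DiffusionEngineStub()))  # True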

modules/sd_models_config.py

Lines changed: 2 additions & 0 deletions
@@ -6,11 +6,13 @@
 
 sd_configs_path = shared.sd_configs_path
 sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
 
 
 config_default = shared.sd_default_config
 config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
+config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
 config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
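
Note the deliberate shadowing: config_sd2v is assigned twice in a row, so after this commit every SD 2.x v-prediction checkpoint resolves to sd_2_1_768.yaml from the generative-models repo instead of v2-inference-v.yaml. That rebinding is the crux of the commit title: SD 2.1 now loads through a config shipped with the SDXL codebase.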

modules/sd_models_xl.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+import torch
+
+import sgm.models.diffusion
+import sgm.modules.diffusionmodules.denoiser_scaling
+import sgm.modules.diffusionmodules.discretizer
+from modules import devices
+
+
+def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]):
+    for embedder in self.conditioner.embedders:
+        embedder.ucg_rate = 0.0
+
+    c = self.conditioner({'txt': batch})
+
+    return c
+
+
+def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+    return self.model(x, t, cond)
+
+
+def extend_sdxl(model):
+    dtype = next(model.model.diffusion_model.parameters()).dtype
+    model.model.diffusion_model.dtype = dtype
+    model.model.conditioning_key = 'crossattn'
+
+    model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0]
+    model.cond_stage_key = model.cond_stage_model.input_key
+
+    model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
+
+    discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
+    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+
+
+sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
+sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
+
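
extend_sdxl rebuilds the ldm-style alphas_cumprod buffer from sgm's LegacyDDPMDiscretization so samplers that expect it on the model keep working, and it infers v- vs eps-parameterization from the denoiser's scaling class, which is how the sd_2_1_768.yaml config above ends up treated as a v-prediction model. A self-contained sketch of what that buffer holds, assuming the classic Stable Diffusion schedule (betas linear in sqrt space over 1000 steps; the exact numbers are the discretizer's responsibility):

    import torch

    # Assumed schedule: the usual Stable Diffusion linear-in-sqrt beta schedule.
    betas = torch.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, 1000) ** 2
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    print(alphas_cumprod.shape)  # torch.Size([1000])
    print(float(alphas_cumprod[0]), float(alphas_cumprod[-1]))  # ~0.99915 ... ~0.0047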

modules/sd_samplers_kdiffusion.py

Lines changed: 35 additions & 10 deletions
@@ -53,6 +53,28 @@
 }
 
 
+def catenate_conds(conds):
+    if not isinstance(conds[0], dict):
+        return torch.cat(conds)
+
+    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
+
+
+def subscript_cond(cond, a, b):
+    if not isinstance(cond, dict):
+        return cond[a:b]
+
+    return {key: vec[a:b] for key, vec in cond.items()}
+
+
+def pad_cond(tensor, repeats, empty):
+    if not isinstance(tensor, dict):
+        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
+
+    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
+    return tensor
+
+
 class CFGDenoiser(torch.nn.Module):
     """
     Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)

@@ -105,10 +127,13 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
 
         if shared.sd_model.model.conditioning_key == "crossattn-adm":
             image_uncond = torch.zeros_like(image_cond)
-            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
         else:
             image_uncond = image_cond
-            make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+            if isinstance(uncond, dict):
+                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
+            else:
+                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
 
         if not is_edit_model:
             x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])

@@ -140,28 +165,28 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
             num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
 
             if num_repeats < 0:
-                tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
+                tensor = pad_cond(tensor, -num_repeats, empty)
                 self.padded_cond_uncond = True
             elif num_repeats > 0:
-                uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
+                uncond = pad_cond(uncond, num_repeats, empty)
                 self.padded_cond_uncond = True
 
         if tensor.shape[1] == uncond.shape[1] or skip_uncond:
             if is_edit_model:
-                cond_in = torch.cat([tensor, uncond, uncond])
+                cond_in = catenate_conds([tensor, uncond, uncond])
             elif skip_uncond:
                 cond_in = tensor
             else:
-                cond_in = torch.cat([tensor, uncond])
+                cond_in = catenate_conds([tensor, uncond])
 
             if shared.batch_cond_uncond:
-                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
+                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
             else:
                 x_out = torch.zeros_like(x_in)
                 for batch_offset in range(0, x_out.shape[0], batch_size):
                     a = batch_offset
                     b = a + batch_size
-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b]))
         else:
             x_out = torch.zeros_like(x_in)
             batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size

@@ -170,14 +195,14 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
                 b = min(a + batch_size, tensor.shape[0])
 
                 if not is_edit_model:
-                    c_crossattn = [tensor[a:b]]
+                    c_crossattn = subscript_cond(tensor, a, b)
                 else:
                     c_crossattn = torch.cat([tensor[a:b]], uncond)
 
                 x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
 
             if not skip_uncond:
-                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
 
         denoised_image_indexes = [x[0][0] for x in conds_list]
         if skip_uncond:
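
The three helpers make the tensor-vs-dict distinction invisible at the call sites above: concatenation, slicing, and padding each dispatch on the cond's type, and pad_cond pads only the 'crossattn' entry because the pooled 'vector' cond has no token axis to pad. A self-contained sketch of the dict path (shapes are made up):

    import torch

    def catenate_conds(conds):
        # torch.cat for plain tensors; key-wise torch.cat for SDXL dict conds.
        if not isinstance(conds[0], dict):
            return torch.cat(conds)
        return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}

    cond = {"crossattn": torch.zeros(1, 77, 2048), "vector": torch.zeros(1, 2816)}
    uncond = {"crossattn": torch.ones(1, 77, 2048), "vector": torch.ones(1, 2816)}

    cond_in = catenate_conds([cond, uncond])
    print(cond_in["crossattn"].shape)  # torch.Size([2, 77, 2048])
    print(cond_in["vector"].shape)     # torch.Size([2, 2816])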
