Skip to content

Commit 594c8e7

Browse files
committed
fix CLIP doing the unneeded normalization
revert SD2.1 back to using the original repo's config; add SDXL's force_zero_embeddings to the negative prompt
1 parent 21aec6f commit 594c8e7

6 files changed

Lines changed: 29 additions & 8 deletions

File tree

‎modules/processing.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
344344

345345
def setup_conds(self):
346346
prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
347-
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height)
347+
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
348348

349349
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
350350
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1

‎modules/prompt_parser.py‎

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,17 @@ class SdConditioning(list):
116116
A list with prompts for stable diffusion's conditioner model.
117117
Can also specify width and height of created image - SDXL needs it.
118118
"""
119-
def __init__(self, prompts, width=None, height=None):
119+
def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
120120
super().__init__()
121121
self.extend(prompts)
122-
self.width = width or getattr(prompts, 'width', None)
123-
self.height = height or getattr(prompts, 'height', None)
122+
123+
if copy_from is None:
124+
copy_from = prompts
125+
126+
self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
127+
self.width = width or getattr(copy_from, 'width', None)
128+
self.height = height or getattr(copy_from, 'height', None)
129+
124130

125131

126132
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
@@ -153,7 +159,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
153159
res.append(cached)
154160
continue
155161

156-
texts = [x[1] for x in prompt_schedule]
162+
texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
157163
conds = model.get_learned_conditioning(texts)
158164

159165
cond_schedule = []

‎modules/sd_hijack.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def hijack(self, m):
190190
if typename == 'FrozenCLIPEmbedder':
191191
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
192192
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
193-
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self)
193+
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
194194
conditioner.embedders[i] = m.cond_stage_model
195195
if typename == 'FrozenOpenCLIPEmbedder2':
196196
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)

‎modules/sd_hijack_clip.py‎

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,18 @@ def encode_embedding_init_text(self, init_text, nvpt):
323323
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
324324

325325
return embedded
326+
327+
328+
class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
329+
def __init__(self, wrapped, hijack):
330+
super().__init__(wrapped, hijack)
331+
332+
def encode_with_transformers(self, tokens):
333+
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
334+
335+
if self.wrapped.layer == "last":
336+
z = outputs.last_hidden_state
337+
else:
338+
z = outputs.hidden_states[self.wrapped.layer_idx]
339+
340+
return z

‎modules/sd_models_config.py‎

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
config_default = shared.sd_default_config
1313
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
1414
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
15-
config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml")
1615
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
1716
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
1817
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")

‎modules/sd_models_xl.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
2222
"target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype),
2323
}
2424

25-
c = self.conditioner(sdxl_conds)
25+
force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch)
26+
c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
2627

2728
return c
2829

0 commit comments

Comments (0)