Skip to content

Commit d686e73

Browse files
committed
support for SD3: infinite prompt length, token counting
1 parent a8fba9a commit d686e73

6 files changed

Lines changed: 278 additions & 139 deletions

File tree

‎modules/models/sd3/sd3_cond.py‎

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import os
2+
import safetensors
3+
import torch
4+
import typing
5+
6+
from transformers import CLIPTokenizer, T5TokenizerFast
7+
8+
from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
9+
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
10+
11+
12+
class SafetensorsMapping(typing.Mapping):
    """Read-only mapping view over an open safetensors file handle.

    Allows a ``safetensors.safe_open`` handle to be passed anywhere a
    dict of tensors is expected (e.g. ``load_state_dict``) without
    materialising every tensor up front: each tensor is loaded lazily
    on item access.
    """

    def __init__(self, file):
        self.file = file

    def __len__(self):
        return len(self.file.keys())

    def __iter__(self):
        # Delegate directly to the file's key list.
        yield from self.file.keys()

    def __getitem__(self, key):
        # Tensors are fetched lazily, one at a time.
        return self.file.get_tensor(key)
26+
27+
# Download locations and architecture hyperparameters for the three SD3
# text encoders (CLIP-L, CLIP-G, T5-XXL).  The configs are passed to the
# model classes in modules.models.sd3.other_impls; the URLs are fetched
# on demand by SD3Cond.load_weights().

CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
# CLIP ViT-L/14 text-model dimensions (768-wide hidden states).
CLIPL_CONFIG = {
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
}

CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
# OpenCLIP bigG text-model dimensions (1280-wide hidden states).
CLIPG_CONFIG = {
    "hidden_act": "gelu",
    "hidden_size": 1280,
    "intermediate_size": 5120,
    "num_attention_heads": 20,
    "num_hidden_layers": 32,
}

T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
# T5-XXL encoder dimensions (4096-wide hidden states).
T5_CONFIG = {
    "d_ff": 10240,
    "d_model": 4096,
    "num_heads": 64,
    "num_layers": 24,
    "vocab_size": 32128,
}
53+
54+
55+
class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
    """Joint CLIP-L + CLIP-G conditioning model for SD3.

    Runs both CLIP encoders on the same token stream and concatenates
    their outputs along the feature dimension, padded out to T5's 4096
    width so the result can later be stacked with T5 output.
    """

    def __init__(self, clip_l, clip_g):
        super().__init__()

        self.clip_l = clip_l
        self.clip_g = clip_g

        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        # Tokenizing the empty string yields [start, end]; reuse those ids.
        # The end-of-text token doubles as the padding token here.
        empty = self.tokenizer('')["input_ids"]
        self.id_start = empty[0]
        self.id_end = empty[1]
        self.id_pad = empty[1]

        # Tell the base class to also surface pooled output (used as the
        # 'vector' conditioning).
        self.return_pooled = True

    def tokenize(self, texts):
        # No truncation and no special tokens: the base class handles
        # chunking, enabling arbitrarily long prompts.
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def encode_with_transformers(self, tokens):
        # CLIP-G gets a copy of the tokens with everything after the first
        # end-of-text id zeroed out, while CLIP-L sees the tokens as-is
        # (presumably matching how each encoder was used in reference SD3
        # code — NOTE(review): confirm against other_impls).
        tokens_g = tokens.clone()

        for batch_pos in range(tokens_g.shape[0]):
            index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
            tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0

        l_out, l_pooled = self.clip_l(tokens)
        g_out, g_pooled = self.clip_g(tokens_g)

        # Concatenate 768-wide L and 1280-wide G features, then zero-pad
        # the feature dim from 2048 up to 4096 to match T5's width.
        lg_out = torch.cat([l_out, g_out], dim=-1)
        lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))

        # Pooled vectors are concatenated to form the 'vector' conditioning.
        vector_out = torch.cat((l_pooled, g_pooled), dim=-1)

        # Attach pooled output as an attribute, as the base class expects
        # when return_pooled is set.
        lg_out.pooled = vector_out
        return lg_out

    def encode_embedding_init_text(self, init_text, nvpt):
        # XXX placeholder: textual-inversion init is not implemented for
        # SD3 yet; returns zeros of the combined L+G width.
        return torch.zeros((nvpt, 768+1280), device=devices.device)  # XXX
94+
95+
96+
class Sd3T5(torch.nn.Module):
    """T5-XXL conditioning model for SD3.

    Tokenizes prompts with attention-weight parsing and runs them through
    the T5-XXL encoder.  When T5 is disabled (or no model was loaded),
    ``forward`` returns zeros of the right shape so the caller can still
    concatenate with the CLIP output.
    """

    def __init__(self, t5xxl):
        super().__init__()

        self.t5xxl = t5xxl
        self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")

        # Tokenizing the empty string padded to length 2 yields
        # [end-of-text, pad], giving us both special token ids.
        empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
        self.id_end = empty[0]
        self.id_pad = empty[1]

    def tokenize(self, texts):
        # No truncation and no special tokens: supports arbitrarily long
        # prompts; the end token is appended manually in tokenize_line.
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def tokenize_line(self, line, *, target_token_count=None):
        """Tokenize one prompt line into (tokens, multipliers).

        Emphasis syntax is parsed unless disabled in settings; BREAK
        markers are skipped.  If target_token_count is given, the result
        is padded or truncated to exactly that many tokens, with
        multipliers kept the same length as tokens.
        """
        if shared.opts.emphasis != "None":
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        tokens = []
        multipliers = []

        for text_tokens, (text, weight) in zip(tokenized, parsed):
            # BREAK with weight -1 is the prompt-chunk separator; it emits
            # no tokens of its own.
            if text == 'BREAK' and weight == -1:
                continue

            tokens += text_tokens
            multipliers += [weight] * len(text_tokens)

        tokens += [self.id_end]
        multipliers += [1.0]

        if target_token_count is not None:
            if len(tokens) < target_token_count:
                # Compute the pad length before extending `tokens`: the
                # previous code recomputed `target_token_count - len(tokens)`
                # after the extension, which is then zero, so `multipliers`
                # came back shorter than `tokens`.
                pad_count = target_token_count - len(tokens)
                tokens += [self.id_pad] * pad_count
                multipliers += [1.0] * pad_count
            else:
                tokens = tokens[0:target_token_count]
                multipliers = multipliers[0:target_token_count]

        return tokens, multipliers

    def forward(self, texts, *, token_count):
        """Encode texts with T5-XXL, padded/truncated to token_count.

        Returns a (batch, token_count, 4096) tensor; zeros when T5 is
        disabled so the output can still be concatenated with CLIP's.
        """
        if not self.t5xxl or not shared.opts.sd3_enable_t5:
            return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)

        tokens_batch = []

        for text in texts:
            tokens, _multipliers = self.tokenize_line(text, target_token_count=token_count)
            tokens_batch.append(tokens)

        # NOTE(review): the multipliers are currently unused by the T5
        # encoder call — confirm whether emphasis should affect T5 output.
        t5_out, t5_pooled = self.t5xxl(tokens_batch)

        return t5_out

    def encode_embedding_init_text(self, init_text, nvpt):
        # XXX placeholder: textual-inversion init is not implemented for
        # SD3 yet; returns zeros of T5's feature width.
        return torch.zeros((nvpt, 4096), device=devices.device)  # XXX
157+
158+
159+
class SD3Cond(torch.nn.Module):
    """Combined SD3 conditioning stack: CLIP-L, CLIP-G and (optionally) T5-XXL.

    Produces the 'crossattn' (token-level) and 'vector' (pooled)
    conditionings expected by the SD3 diffusion model.  Text-encoder
    weights are downloaded and loaded lazily via load_weights().
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.tokenizer = SD3Tokenizer()

        # Build encoders on CPU without tracking gradients; weights are
        # filled in later by load_weights().
        with torch.no_grad():
            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)

        # T5 is optional: it is large, and the option lets users skip it.
        if shared.opts.sd3_enable_t5:
            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
        else:
            self.t5xxl = None

        # Wrappers implementing tokenization/encoding on top of the raw models.
        self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
        self.model_t5 = Sd3T5(self.t5xxl)

        # Guards load_weights() against repeated downloads/loads.
        self.weights_loaded = False

    def forward(self, prompts: list[str]):
        """Encode prompts into SD3 conditioning.

        Returns a dict with 'crossattn' — CLIP L+G features concatenated
        with T5 features along the token axis — and 'vector', the pooled
        CLIP output.  T5 output is padded/truncated to the CLIP token
        count so the concatenation lines up.
        """
        lg_out, vector_out = self.model_lg(prompts)

        token_count = lg_out.shape[1]

        t5_out = self.model_t5(prompts, token_count=token_count)
        lgt_out = torch.cat([lg_out, t5_out], dim=-2)

        return {
            'crossattn': lgt_out,
            'vector': vector_out,
        }

    def load_weights(self):
        """Download (if needed) and load text-encoder weights; idempotent."""
        if self.weights_loaded:
            return

        clip_path = os.path.join(shared.models_path, "CLIP")

        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
        with safetensors.safe_open(clip_g_file, framework="pt") as file:
            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))

        # strict=False: the checkpoint presumably carries extra/missing
        # keys relative to our model definition — NOTE(review): confirm.
        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
        with safetensors.safe_open(clip_l_file, framework="pt") as file:
            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        if self.t5xxl:
            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
            with safetensors.safe_open(t5_file, framework="pt") as file:
                self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        self.weights_loaded = True

    def encode_embedding_init_text(self, init_text, nvpt):
        # XXX placeholder: textual-inversion init not implemented for SD3.
        return torch.tensor([[0]], device=devices.device)  # XXX

    def medvram_modules(self):
        # Modules eligible for CPU offloading under --medvram.
        # NOTE(review): may contain None when T5 is disabled — confirm
        # callers tolerate that.
        return [self.clip_g, self.clip_l, self.t5xxl]

    def get_token_count(self, text):
        # Token counting uses the CLIP L/G path only (T5 is padded to match).
        _, token_count = self.model_lg.process_texts([text])

        return token_count

    def get_target_prompt_token_count(self, token_count):
        # Delegate to the CLIP model's chunk-size rounding.
        return self.model_lg.get_target_prompt_token_count(token_count)

‎modules/models/sd3/sd3_model.py‎

Lines changed: 2 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -1,127 +1,12 @@
11
import contextlib
2-
import os
3-
from typing import Mapping
42

5-
import safetensors
63
import torch
74

85
import k_diffusion
9-
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
106
from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
7+
from modules.models.sd3.sd3_cond import SD3Cond
118

12-
from modules import shared, modelloader, devices
13-
14-
CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
15-
CLIPG_CONFIG = {
16-
"hidden_act": "gelu",
17-
"hidden_size": 1280,
18-
"intermediate_size": 5120,
19-
"num_attention_heads": 20,
20-
"num_hidden_layers": 32,
21-
}
22-
23-
CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
24-
CLIPL_CONFIG = {
25-
"hidden_act": "quick_gelu",
26-
"hidden_size": 768,
27-
"intermediate_size": 3072,
28-
"num_attention_heads": 12,
29-
"num_hidden_layers": 12,
30-
}
31-
32-
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
33-
T5_CONFIG = {
34-
"d_ff": 10240,
35-
"d_model": 4096,
36-
"num_heads": 64,
37-
"num_layers": 24,
38-
"vocab_size": 32128,
39-
}
40-
41-
42-
class SafetensorsMapping(Mapping):
43-
def __init__(self, file):
44-
self.file = file
45-
46-
def __len__(self):
47-
return len(self.file.keys())
48-
49-
def __iter__(self):
50-
for key in self.file.keys():
51-
yield key
52-
53-
def __getitem__(self, key):
54-
return self.file.get_tensor(key)
55-
56-
57-
class SD3Cond(torch.nn.Module):
58-
def __init__(self, *args, **kwargs):
59-
super().__init__(*args, **kwargs)
60-
61-
self.tokenizer = SD3Tokenizer()
62-
63-
with torch.no_grad():
64-
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
65-
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
66-
67-
if shared.opts.sd3_enable_t5:
68-
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
69-
else:
70-
self.t5xxl = None
71-
72-
self.weights_loaded = False
73-
74-
def forward(self, prompts: list[str]):
75-
res = []
76-
77-
for prompt in prompts:
78-
tokens = self.tokenizer.tokenize_with_weights(prompt)
79-
l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
80-
g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
81-
82-
if self.t5xxl and shared.opts.sd3_enable_t5:
83-
t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
84-
else:
85-
t5_out = torch.zeros(l_out.shape[0:2] + (4096,), dtype=l_out.dtype, device=l_out.device)
86-
87-
lg_out = torch.cat([l_out, g_out], dim=-1)
88-
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
89-
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
90-
vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
91-
92-
res.append({
93-
'crossattn': lgt_out[0].to(devices.device),
94-
'vector': vector_out[0].to(devices.device),
95-
})
96-
97-
return res
98-
99-
def load_weights(self):
100-
if self.weights_loaded:
101-
return
102-
103-
clip_path = os.path.join(shared.models_path, "CLIP")
104-
105-
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
106-
with safetensors.safe_open(clip_g_file, framework="pt") as file:
107-
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
108-
109-
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
110-
with safetensors.safe_open(clip_l_file, framework="pt") as file:
111-
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
112-
113-
if self.t5xxl:
114-
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
115-
with safetensors.safe_open(t5_file, framework="pt") as file:
116-
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
117-
118-
self.weights_loaded = True
119-
120-
def encode_embedding_init_text(self, init_text, nvpt):
121-
return torch.tensor([[0]], device=devices.device) # XXX
122-
123-
def medvram_modules(self):
124-
return [self.clip_g, self.clip_l, self.t5xxl]
9+
from modules import shared, devices
12510

12611

12712
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):

‎modules/prompt_parser.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
268268

269269

270270
class DictWithShape(dict):
271-
def __init__(self, x, shape):
271+
def __init__(self, x, shape=None):
272272
super().__init__()
273273
self.update(x)
274274

‎modules/sd_hijack.py‎

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,10 @@ def get_prompt_lengths(self, text):
325325
if self.clip is None:
326326
return "-", "-"
327327

328-
_, token_count = self.clip.process_texts([text])
328+
if hasattr(self.clip, 'get_token_count'):
329+
token_count = self.clip.get_token_count(text)
330+
else:
331+
_, token_count = self.clip.process_texts([text])
329332

330333
return token_count, self.clip.get_target_prompt_token_count(token_count)
331334

0 commit comments

Comments
 (0)