|
| 1 | +import os |
| 2 | +import safetensors |
| 3 | +import torch |
| 4 | +import typing |
| 5 | + |
| 6 | +from transformers import CLIPTokenizer, T5TokenizerFast |
| 7 | + |
| 8 | +from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser |
| 9 | +from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer |
| 10 | + |
| 11 | + |
class SafetensorsMapping(typing.Mapping):
    """Read-only Mapping view over an open safetensors file handle.

    Allows a handle returned by ``safetensors.safe_open`` to be passed
    anywhere a dict of tensors is expected (e.g. ``load_state_dict``)
    while tensors are fetched lazily, one key at a time.
    """

    def __init__(self, file):
        self.file = file

    def __len__(self):
        return len(self.file.keys())

    def __iter__(self):
        yield from self.file.keys()

    def __getitem__(self, key):
        return self.file.get_tensor(key)
| 25 | + |
| 26 | + |
# Download locations and transformer hyperparameters for the three SD3 text
# encoders. The config dicts are handed to the model constructors in
# SD3Cond.__init__; the URLs are fetched on demand by SD3Cond.load_weights().

# CLIP-L: 768-channel text encoder (config consumed by SDClipModel).
CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
CLIPL_CONFIG = {
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
}

# CLIP-G: 1280-channel text encoder (config consumed by SDXLClipG).
CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
CLIPG_CONFIG = {
    "hidden_act": "gelu",
    "hidden_size": 1280,
    "intermediate_size": 5120,
    "num_attention_heads": 20,
    "num_hidden_layers": 32,
}

# T5-XXL: 4096-channel text encoder (config consumed by T5XXLModel).
# Optional — only instantiated when shared.opts.sd3_enable_t5 is set.
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
T5_CONFIG = {
    "d_ff": 10240,
    "d_model": 4096,
    "num_heads": 64,
    "num_layers": 24,
    "vocab_size": 32128,
}
| 53 | + |
| 54 | + |
class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
    """Joint CLIP-L + CLIP-G conditioner for SD3.

    Both encoders share one CLIP tokenization; their per-token outputs are
    concatenated channel-wise (768 + 1280) and zero-padded to 4096 channels,
    and their pooled outputs are concatenated into the 2048-wide vector
    conditioning. `return_pooled` makes the base class surface that vector.
    """

    def __init__(self, clip_l, clip_g):
        super().__init__()

        self.clip_l = clip_l
        self.clip_g = clip_g

        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        # Tokenizing the empty string yields [start-of-text, end-of-text];
        # CLIP uses the end token for padding as well.
        bos, eos = self.tokenizer('')["input_ids"][:2]
        self.id_start = bos
        self.id_end = eos
        self.id_pad = eos

        self.return_pooled = True

    def tokenize(self, texts):
        tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)
        return tokenized["input_ids"]

    def encode_with_transformers(self, tokens):
        # CLIP-G gets a masked copy of the batch: everything after the first
        # end-of-text token in each row is zeroed out.
        tokens_g = tokens.clone()
        for row in tokens_g:
            end_pos = row.cpu().tolist().index(self.id_end)
            row[end_pos + 1:] = 0

        l_out, l_pooled = self.clip_l(tokens)
        g_out, g_pooled = self.clip_g(tokens_g)

        # Concatenate per-token features and pad the channel dimension up to
        # 4096 so the result can later be stacked with the T5 stream.
        lg_out = torch.cat([l_out, g_out], dim=-1)
        lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))

        # The pooled vector rides along as an attribute on the tensor, as
        # expected by the sd_hijack_clip machinery when return_pooled is set.
        lg_out.pooled = torch.cat((l_pooled, g_pooled), dim=-1)
        return lg_out

    def encode_embedding_init_text(self, init_text, nvpt):
        # XXX placeholder — embedding init is not implemented; returns zeros.
        return torch.zeros((nvpt, 768 + 1280), device=devices.device)
| 94 | + |
| 95 | + |
class Sd3T5(torch.nn.Module):
    """Optional T5-XXL conditioner for SD3.

    Wraps a T5XXLModel (which may be None) and produces per-token embeddings
    of width 4096. When the encoder is absent or disabled via
    shared.opts.sd3_enable_t5, forward() returns zeros of the right shape so
    the caller can concatenate unconditionally.
    """

    def __init__(self, t5xxl):
        super().__init__()

        self.t5xxl = t5xxl
        self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")

        # Tokenizing "" padded to length 2 yields [end-of-text, pad] ids.
        empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
        self.id_end = empty[0]
        self.id_pad = empty[1]

    def tokenize(self, texts):
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def tokenize_line(self, line, *, target_token_count=None):
        """Tokenize one prompt line, honoring attention-weight syntax.

        Returns (tokens, multipliers) of equal length; when
        target_token_count is given, both lists are padded or truncated to
        exactly that length.
        """
        if shared.opts.emphasis != "None":
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        tokens = []
        multipliers = []

        for text_tokens, (text, weight) in zip(tokenized, parsed):
            if text == 'BREAK' and weight == -1:
                continue

            tokens += text_tokens
            multipliers += [weight] * len(text_tokens)

        tokens += [self.id_end]
        multipliers += [1.0]

        if target_token_count is not None:
            if len(tokens) < target_token_count:
                # Bug fix: compute the pad amount once, before extending.
                # The original extended `tokens` first, which made the
                # subsequent `target_token_count - len(tokens)` zero and left
                # `multipliers` shorter than `tokens`.
                pad_count = target_token_count - len(tokens)
                tokens += [self.id_pad] * pad_count
                multipliers += [1.0] * pad_count
            else:
                tokens = tokens[0:target_token_count]
                multipliers = multipliers[0:target_token_count]

        return tokens, multipliers

    def forward(self, texts, *, token_count):
        """Encode texts into a (len(texts), token_count, 4096) tensor."""
        if not self.t5xxl or not shared.opts.sd3_enable_t5:
            # T5 disabled or absent: contribute zeros so concatenation with
            # the CLIP stream still produces a tensor of the expected shape.
            return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)

        tokens_batch = []

        for text in texts:
            # Multipliers are not applied to the T5 stream; only tokens are used.
            tokens, _ = self.tokenize_line(text, target_token_count=token_count)
            tokens_batch.append(tokens)

        t5_out, _ = self.t5xxl(tokens_batch)

        return t5_out

    def encode_embedding_init_text(self, init_text, nvpt):
        # XXX placeholder — embedding init is not implemented; returns zeros.
        return torch.zeros((nvpt, 4096), device=devices.device)
| 157 | + |
| 158 | + |
class SD3Cond(torch.nn.Module):
    """Combined text conditioner for SD3: CLIP-L + CLIP-G, plus optional T5-XXL.

    forward() returns the conditioning dict: 'crossattn' holds the per-token
    embeddings of all encoders concatenated along the token axis, 'vector'
    holds the pooled CLIP output. Weights are not loaded at construction time;
    call load_weights() before use.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.tokenizer = SD3Tokenizer()

        # Models are built on CPU with empty/random weights; the real weights
        # are loaded later by load_weights().
        with torch.no_grad():
            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)

        # T5 is optional and controlled by a user setting; Sd3T5 tolerates None.
        if shared.opts.sd3_enable_t5:
            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
        else:
            self.t5xxl = None

        self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
        self.model_t5 = Sd3T5(self.t5xxl)

        # Guard so load_weights() downloads/loads the checkpoints only once.
        self.weights_loaded = False

    def forward(self, prompts: list[str]):
        # The CLIP branch determines the token count (dim 1 of its output);
        # the T5 branch pads/truncates to match so both streams can be
        # concatenated along the token axis (dim=-2).
        lg_out, vector_out = self.model_lg(prompts)

        token_count = lg_out.shape[1]

        t5_out = self.model_t5(prompts, token_count=token_count)
        lgt_out = torch.cat([lg_out, t5_out], dim=-2)

        return {
            'crossattn': lgt_out,
            'vector': vector_out,
        }

    def load_weights(self):
        """Download (if needed) and load encoder weights into the CPU models.

        Idempotent: subsequent calls are no-ops once weights_loaded is set.
        """
        if self.weights_loaded:
            return

        clip_path = os.path.join(shared.models_path, "CLIP")

        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
        with safetensors.safe_open(clip_g_file, framework="pt") as file:
            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))

        # strict=False: the checkpoint does not cover every module parameter
        # for CLIP-L and T5 — NOTE(review): presumably intentional; confirm.
        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
        with safetensors.safe_open(clip_l_file, framework="pt") as file:
            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        if self.t5xxl:
            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
            with safetensors.safe_open(t5_file, framework="pt") as file:
                self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        self.weights_loaded = True

    def encode_embedding_init_text(self, init_text, nvpt):
        # XXX placeholder — embedding init is not implemented for SD3.
        return torch.tensor([[0]], device=devices.device)  # XXX

    def medvram_modules(self):
        # Modules eligible for low-VRAM offloading; self.t5xxl may be None
        # when T5 is disabled — NOTE(review): confirm callers handle None.
        return [self.clip_g, self.clip_l, self.t5xxl]

    def get_token_count(self, text):
        # Token accounting is delegated to the CLIP branch (T5 follows it).
        _, token_count = self.model_lg.process_texts([text])

        return token_count

    def get_target_prompt_token_count(self, token_count):
        return self.model_lg.get_target_prompt_token_count(token_count)
0 commit comments