Skip to content

Commit 6c6ae28

Browse files
committed
send all three of GFPGAN's and codeformer's models to CPU memory instead of just one for #1283
1 parent 556c36b commit 6c6ae28

4 files changed

Lines changed: 41 additions & 11 deletions

File tree

‎modules/codeformer_model.py‎

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,14 @@ def create_models(self):
6969

7070
self.net = net
7171
self.face_helper = face_helper
72-
self.net.to(devices.device_codeformer)
7372

7473
return net, face_helper
7574

75+
def send_model_to(self, device):
76+
self.net.to(device)
77+
self.face_helper.face_det.to(device)
78+
self.face_helper.face_parse.to(device)
79+
7680
def restore(self, np_image, w=None):
7781
np_image = np_image[:, :, ::-1]
7882

@@ -82,6 +86,8 @@ def restore(self, np_image, w=None):
8286
if self.net is None or self.face_helper is None:
8387
return np_image
8488

89+
self.send_model_to(devices.device_codeformer)
90+
8591
self.face_helper.clean_all()
8692
self.face_helper.read_image(np_image)
8793
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@@ -113,8 +119,10 @@ def restore(self, np_image, w=None):
113119
if original_resolution != restored_img.shape[0:2]:
114120
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
115121

122+
self.face_helper.clean_all()
123+
116124
if shared.opts.face_restoration_unload:
117-
self.net.to(devices.cpu)
125+
self.send_model_to(devices.cpu)
118126

119127
return restored_img
120128

‎modules/devices.py‎

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import contextlib
2+
13
import torch
24

35
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
@@ -57,3 +59,11 @@ def randn_without_seed(shape):
5759

5860
return torch.randn(shape, device=device)
5961

62+
63+
def autocast():
64+
from modules import shared
65+
66+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
67+
return contextlib.nullcontext()
68+
69+
return torch.autocast("cuda")

‎modules/gfpgan_model.py‎

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,32 @@ def gfpgann():
3737
print("Unable to load gfpgan model!")
3838
return None
3939
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
40-
model.gfpgan.to(shared.device)
4140
loaded_gfpgan_model = model
4241

4342
return model
4443

4544

45+
def send_model_to(model, device):
46+
model.gfpgan.to(device)
47+
model.face_helper.face_det.to(device)
48+
model.face_helper.face_parse.to(device)
49+
50+
4651
def gfpgan_fix_faces(np_image):
4752
model = gfpgann()
4853
if model is None:
4954
return np_image
55+
56+
send_model_to(model, devices.device)
57+
5058
np_image_bgr = np_image[:, :, ::-1]
5159
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
5260
np_image = gfpgan_output_bgr[:, :, ::-1]
5361

62+
model.face_helper.clean_all()
63+
5464
if shared.opts.face_restoration_unload:
55-
model.gfpgan.to(devices.cpu)
65+
send_model_to(model, devices.cpu)
5666

5767
return np_image
5868

‎modules/processing.py‎

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import contextlib
21
import json
32
import math
43
import os
@@ -330,9 +329,8 @@ def infotext(iteration=0, position_in_batch=0):
330329

331330
infotexts = []
332331
output_images = []
333-
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
334-
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
335-
with torch.no_grad(), precision_scope("cuda"), ema_scope():
332+
333+
with torch.no_grad():
336334
p.init(all_prompts, all_seeds, all_subseeds)
337335

338336
if state.job_count == -1:
@@ -351,8 +349,9 @@ def infotext(iteration=0, position_in_batch=0):
351349

352350
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
353351
#c = p.sd_model.get_learned_conditioning(prompts)
354-
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
355-
c = prompt_parser.get_learned_conditioning(prompts, p.steps)
352+
with devices.autocast():
353+
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
354+
c = prompt_parser.get_learned_conditioning(prompts, p.steps)
356355

357356
if len(model_hijack.comments) > 0:
358357
for comment in model_hijack.comments:
@@ -361,7 +360,9 @@ def infotext(iteration=0, position_in_batch=0):
361360
if p.n_iter > 1:
362361
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
363362

364-
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
363+
with devices.autocast():
364+
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength).to(devices.dtype)
365+
365366
if state.interrupted:
366367

367368
# if we are interrupted, sample returns just noise
@@ -386,6 +387,7 @@ def infotext(iteration=0, position_in_batch=0):
386387
devices.torch_gc()
387388

388389
x_sample = modules.face_restoration.restore_faces(x_sample)
390+
devices.torch_gc()
389391

390392
image = Image.fromarray(x_sample)
391393

0 commit comments

Comments
 (0)