1- import contextlib
21import json
32import math
43import os
@@ -330,9 +329,8 @@ def infotext(iteration=0, position_in_batch=0):
330329
331330 infotexts = []
332331 output_images = []
333- precision_scope = torch .autocast if cmd_opts .precision == "autocast" else contextlib .nullcontext
334- ema_scope = (contextlib .nullcontext if cmd_opts .lowvram else p .sd_model .ema_scope )
335- with torch .no_grad (), precision_scope ("cuda" ), ema_scope ():
332+
333+ with torch .no_grad ():
336334 p .init (all_prompts , all_seeds , all_subseeds )
337335
338336 if state .job_count == - 1 :
@@ -351,8 +349,9 @@ def infotext(iteration=0, position_in_batch=0):
351349
352350 #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
353351 #c = p.sd_model.get_learned_conditioning(prompts)
354- uc = prompt_parser .get_learned_conditioning (len (prompts ) * [p .negative_prompt ], p .steps )
355- c = prompt_parser .get_learned_conditioning (prompts , p .steps )
352+ with devices .autocast ():
353+ uc = prompt_parser .get_learned_conditioning (len (prompts ) * [p .negative_prompt ], p .steps )
354+ c = prompt_parser .get_learned_conditioning (prompts , p .steps )
356355
357356 if len (model_hijack .comments ) > 0 :
358357 for comment in model_hijack .comments :
@@ -361,7 +360,9 @@ def infotext(iteration=0, position_in_batch=0):
361360 if p .n_iter > 1 :
362361 shared .state .job = f"Batch { n + 1 } out of { p .n_iter } "
363362
364- samples_ddim = p .sample (conditioning = c , unconditional_conditioning = uc , seeds = seeds , subseeds = subseeds , subseed_strength = p .subseed_strength )
363+ with devices .autocast ():
364+ samples_ddim = p .sample (conditioning = c , unconditional_conditioning = uc , seeds = seeds , subseeds = subseeds , subseed_strength = p .subseed_strength ).to (devices .dtype )
365+
365366 if state .interrupted :
366367
367368 # if we are interruped, sample returns just noise
@@ -386,6 +387,7 @@ def infotext(iteration=0, position_in_batch=0):
386387 devices .torch_gc ()
387388
388389 x_sample = modules .face_restoration .restore_faces (x_sample )
390+ devices .torch_gc ()
389391
390392 image = Image .fromarray (x_sample )
391393
0 commit comments