@@ -53,19 +53,46 @@ def first_stage_model_decode_wrap(z):
5353 send_me_to_gpu (first_stage_model , None )
5454 return first_stage_model_decode (z )
5555
56- # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
57- if hasattr (sd_model .cond_stage_model , 'model' ):
58- sd_model .cond_stage_model .transformer = sd_model .cond_stage_model .model
59-
60- # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
61- # send the model to GPU. Then put modules back. the modules will be in CPU.
62- stored = sd_model .cond_stage_model .transformer , sd_model .first_stage_model , getattr (sd_model , 'depth_model' , None ), getattr (sd_model , 'embedder' , None ), sd_model .model
63- sd_model .cond_stage_model .transformer , sd_model .first_stage_model , sd_model .depth_model , sd_model .embedder , sd_model .model = None , None , None , None , None
56+ to_remain_in_cpu = [
57+ (sd_model , 'first_stage_model' ),
58+ (sd_model , 'depth_model' ),
59+ (sd_model , 'embedder' ),
60+ (sd_model , 'model' ),
61+ (sd_model , 'embedder' ),
62+ ]
63+
64+ is_sdxl = hasattr (sd_model , 'conditioner' )
65+ is_sd2 = not is_sdxl and hasattr (sd_model .cond_stage_model , 'model' )
66+
67+ if is_sdxl :
68+ to_remain_in_cpu .append ((sd_model , 'conditioner' ))
69+ elif is_sd2 :
70+ to_remain_in_cpu .append ((sd_model .cond_stage_model , 'model' ))
71+ else :
72+ to_remain_in_cpu .append ((sd_model .cond_stage_model , 'transformer' ))
73+
74+ # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
75+ stored = []
76+ for obj , field in to_remain_in_cpu :
77+ module = getattr (obj , field , None )
78+ stored .append (module )
79+ setattr (obj , field , None )
80+
81+ # send the model to GPU.
6482 sd_model .to (devices .device )
65- sd_model .cond_stage_model .transformer , sd_model .first_stage_model , sd_model .depth_model , sd_model .embedder , sd_model .model = stored
83+
84+ # put modules back. the modules will be in CPU.
85+ for (obj , field ), module in zip (to_remain_in_cpu , stored ):
86+ setattr (obj , field , module )
6687
6788 # register hooks for those the first three models
68- sd_model .cond_stage_model .transformer .register_forward_pre_hook (send_me_to_gpu )
89+ if is_sdxl :
90+ sd_model .conditioner .register_forward_pre_hook (send_me_to_gpu )
91+ elif is_sd2 :
92+ sd_model .cond_stage_model .model .register_forward_pre_hook (send_me_to_gpu )
93+ else :
94+ sd_model .cond_stage_model .transformer .register_forward_pre_hook (send_me_to_gpu )
95+
6996 sd_model .first_stage_model .register_forward_pre_hook (send_me_to_gpu )
7097 sd_model .first_stage_model .encode = first_stage_model_encode_wrap
7198 sd_model .first_stage_model .decode = first_stage_model_decode_wrap
@@ -75,10 +102,6 @@ def first_stage_model_decode_wrap(z):
75102 sd_model .embedder .register_forward_pre_hook (send_me_to_gpu )
76103 parents [sd_model .cond_stage_model .transformer ] = sd_model .cond_stage_model
77104
78- if hasattr (sd_model .cond_stage_model , 'model' ):
79- sd_model .cond_stage_model .model = sd_model .cond_stage_model .transformer
80- del sd_model .cond_stage_model .transformer
81-
82105 if use_medvram :
83106 sd_model .model .register_forward_pre_hook (send_me_to_gpu )
84107 else :
0 commit comments