@@ -450,7 +450,7 @@ def __init__(self,
450450 self .cond_stage_key = cond_stage_key
451451 try :
452452 self .num_downs = len (first_stage_config .params .ddconfig .ch_mult ) - 1
453- except :
453+ except Exception :
454454 self .num_downs = 0
455455 if not scale_by_std :
456456 self .scale_factor = scale_factor
@@ -877,16 +877,6 @@ def forward(self, x, c, *args, **kwargs):
877877 c = self .q_sample (x_start = c , t = tc , noise = torch .randn_like (c .float ()))
878878 return self .p_losses (x , c , t , * args , ** kwargs )
879879
880- def _rescale_annotations (self , bboxes , crop_coordinates ): # TODO: move to dataset
881- def rescale_bbox (bbox ):
882- x0 = clamp ((bbox [0 ] - crop_coordinates [0 ]) / crop_coordinates [2 ])
883- y0 = clamp ((bbox [1 ] - crop_coordinates [1 ]) / crop_coordinates [3 ])
884- w = min (bbox [2 ] / crop_coordinates [2 ], 1 - x0 )
885- h = min (bbox [3 ] / crop_coordinates [3 ], 1 - y0 )
886- return x0 , y0 , w , h
887-
888- return [rescale_bbox (b ) for b in bboxes ]
889-
890880 def apply_model (self , x_noisy , t , cond , return_ids = False ):
891881
892882 if isinstance (cond , dict ):
@@ -1157,8 +1147,10 @@ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quanti
11571147
11581148 if i % log_every_t == 0 or i == timesteps - 1 :
11591149 intermediates .append (x0_partial )
1160- if callback : callback (i )
1161- if img_callback : img_callback (img , i )
1150+ if callback :
1151+ callback (i )
1152+ if img_callback :
1153+ img_callback (img , i )
11621154 return img , intermediates
11631155
11641156 @torch .no_grad ()
@@ -1205,8 +1197,10 @@ def p_sample_loop(self, cond, shape, return_intermediates=False,
12051197
12061198 if i % log_every_t == 0 or i == timesteps - 1 :
12071199 intermediates .append (img )
1208- if callback : callback (i )
1209- if img_callback : img_callback (img , i )
1200+ if callback :
1201+ callback (i )
1202+ if img_callback :
1203+ img_callback (img , i )
12101204
12111205 if return_intermediates :
12121206 return img , intermediates
@@ -1322,7 +1316,7 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
13221316
13231317 if inpaint :
13241318 # make a simple center square
1325- b , h , w = z . shape [ 0 ], z .shape [2 ], z .shape [3 ]
1319+ h , w = z .shape [2 ], z .shape [3 ]
13261320 mask = torch .ones (N , h , w ).to (self .device )
13271321 # zeros will be filled in
13281322 mask [:, h // 4 :3 * h // 4 , w // 4 :3 * w // 4 ] = 0.
0 commit comments