Continued from before: Why does my model work with `tf.GradientTape()` but fail when using `keras.models.Model.fit()`?

I'm working on replicating the perceptual style transfer model.

I finally have my model learning as expected on 1000 images from the COCO2014 dataset. But then I tried to run 2 epochs on the entire dataset, with 20695 batches per epoch (as per the research paper). It starts learning very quickly, but after about 3700 steps it mysteriously fails. I save 1 generated image every 100 batches, and the saved images show the failure.

The predictions I make with the saved checkpoints show similar results.

Looking at the losses near the point of failure, I see:

# output_1 is content_loss
# output_2-6 are gram matrix style_loss values
 [batch:3400/20695] - loss: 953168.7218 - output_1_loss: 123929.1953 - output_2_loss: 55090.2109 - output_3_loss: 168500.2344 - output_4_loss: 139039.1250 - output_5_loss: 355890.0312 - output_6_loss: 110718.5781

 [batch:3500/20695] - loss: 935344.0219 - output_1_loss: 124042.5938 - output_2_loss: 53807.3516 - output_3_loss: 164373.4844 - output_4_loss: 135753.5938 - output_5_loss: 348085.6250 - output_6_loss: 109280.0469

 [batch:3600/20695] - loss: 918017.2146 - output_1_loss: 124055.9922 - output_2_loss: 52535.9062 - output_3_loss: 160401.0469 - output_4_loss: 132601.0156 - output_5_loss: 340561.5938 - output_6_loss: 107860.3047

 [batch:3700/20695] - loss: 901454.0553 - output_1_loss: 124096.1328 - output_2_loss: 51326.8672 - output_3_loss: 156607.0312 - output_4_loss: 129584.2578 - output_5_loss: 333345.5312 - output_6_loss: 106493.0781

 [batch:3750/20695] - loss: 893397.4667 - output_1_loss: 124108.4531 - output_2_loss: 50735.1992 - output_3_loss: 154768.8281 - output_4_loss: 128128.1953 - output_5_loss: 329850.2188 - output_6_loss: 105805.6250

# total loss increases after batch=3750. WHY???

 [batch:3800/20695] - loss: 1044768.7239 - output_1_loss: 123897.2188 - output_2_loss: 101063.2812 - output_3_loss: 200778.2812 - output_4_loss: 141584.6875 - output_5_loss: 370377.5000 - output_6_loss: 107066.7812

 [batch:3900/20695] - loss: 1479362.4735 - output_1_loss: 123050.9766 - output_2_loss: 200276.5156 - output_3_loss: 356414.2188 - output_4_loss: 185420.0781 - output_5_loss: 502506.7500 - output_6_loss: 111692.8750 

I can't begin to think of how to debug this problem. Once it "works", should the model continue to work? It seems like some kind of buffer overflow, but I have no idea how to find it. Any ideas?
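One way to start narrowing this down is to locate the first logged step where the total loss rises, then inspect the input batches just before it. The sketch below does that for the loss values quoted above; `first_increase` is a hypothetical helper name, not part of the notebook.

```python
# (batch, total loss) pairs copied from the log above
logged = [
    (3400, 953168.7218),
    (3500, 935344.0219),
    (3600, 918017.2146),
    (3700, 901454.0553),
    (3750, 893397.4667),
    (3800, 1044768.7239),
    (3900, 1479362.4735),
]


def first_increase(history):
    """Return the first batch index whose loss is higher than the previous entry."""
    for (_, prev_loss), (batch, loss) in zip(history, history[1:]):
        if loss > prev_loss:
            return batch
    return None


print(first_increase(logged))  # prints 3800
```

Since the log is only sampled every 50 to 100 batches, the offending input could be anywhere in the window before the reported step, so the batches in that window are the ones worth inspecting individually.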

The full Colab notebook/repo can be found here: https://colab.research.google.com/github/mixuala/fast_neural_style_pytorch/blob/master/notebook/%5BSO%5D_Coco14_FastStyleTransfer.ipynb

2 Answers

0

You can try two classical methods here:

  1. Learning rate decay. Decay every 100 batches or so, instead of every epoch.

  2. Gradient clipping. Clip gradients to a specified range. For the generative networks I've used before, clipping to between -5 and 5 worked well. If you think the network is learning too slowly, you can widen the range.
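The two ideas can be sketched in plain Python as follows. This is only an illustration of the arithmetic: the function names, the initial rate of 1e-3, and the decay factor of 0.96 are assumptions, not values from the question. In TensorFlow the same effect is usually obtained with the built-ins, for example `tf.keras.optimizers.schedules.ExponentialDecay` with `staircase=True` and the optimizer's `clipvalue` argument.

```python
def step_decay(initial_lr, step, decay_every=100, decay_rate=0.96):
    """Staircase decay: multiply the rate by decay_rate once per decay_every batches."""
    return initial_lr * decay_rate ** (step // decay_every)


def clip_by_value(grads, low=-5.0, high=5.0):
    """Clamp every gradient component into the range [low, high]."""
    return [min(max(g, low), high) for g in grads]


# After 250 batches, two full decay periods have elapsed.
lr = step_decay(1e-3, step=250)

# Components outside [-5, 5] are clamped; the small one passes through unchanged.
clipped = clip_by_value([0.3, -12.0, 80.0])
print(lr, clipped)
```

Clipping by value bounds each component independently, so a single extreme gradient cannot produce an arbitrarily large weight update.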


2 Comments

It actually learns, then it suddenly fails. Maybe that's a sign of exploding gradients; I'm still trying to learn the intuition. For now, I excluded the first 4K batches and it seems to continue to learn just fine. Maybe the problem is an input tensor whose domain is out of bounds, or an alpha channel? I'm using `tf.image.decode_jpeg()` in my preprocessing code.
[SOLVED]: It was an all-black input image that caused a divide-by-zero error in the domain scaling code.
0

I found a saturated white image, RGB=255, that caused the model to become unstable. It appeared in batch=3696 with batch_size=4. When I skipped that batch, everything worked fine.

I know that some monitoring code hit a divide-by-zero error when trying to normalize the domain of the image, but I'm not sure whether that error is connected to the model destabilization. The generated image from the model was all black.
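A uniform image (all black or all white) has max equal to min, so a min-max scaling step divides by zero. Below is a minimal sketch of a guarded version, plus a filter that drops uniform images from a batch. The actual scaling code is not shown here, so `scale_to_unit` and `is_uniform` are hypothetical names, and the images are flattened to plain lists for brevity. In TensorFlow, `tf.math.divide_no_nan` offers a similar guard.

```python
def is_uniform(pixels, eps=1e-8):
    """True when every pixel has (almost) the same value, e.g. all black or all white."""
    return max(pixels) - min(pixels) < eps


def scale_to_unit(pixels, eps=1e-8):
    """Min-max scale a flat list of pixel values to [0, 1].

    A uniform image would divide by zero, so it is mapped to all zeros instead.
    """
    lo, hi = min(pixels), max(pixels)
    span = hi - lo
    if span < eps:
        return [0.0 for _ in pixels]
    return [(p - lo) / span for p in pixels]


# Toy batch: one normal image, one saturated white, one all black.
batch = [[0, 128, 255], [255, 255, 255], [0, 0, 0]]
kept = [img for img in batch if not is_uniform(img)]
print(kept)                            # only the first image survives the filter
print(scale_to_unit([255, 255, 255]))  # no exception for the white image
```

Filtering the inputs and guarding the division are independent choices: the filter keeps degenerate images out of training, while the guard keeps monitoring code from failing on them.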
