
After much effort, I managed to build a TensorFlow 2 implementation of an existing PyTorch style-transfer project. Then I wanted to get all the nice extra features that come with standard Keras training, e.g. model.fit().

But the same model fails when trained with model.fit(). The model seems to learn the content features, but it is unable to learn the style features. This is the diagram of the model in question:

[Image: diagram of the model]

import tensorflow as tf
from tensorflow.keras.models import Model

def vgg_layers19(content_layers, style_layers, input_shape=(256,256,3)):
  """ creates a VGG model that returns output values for the given layers
  see: https://keras.io/applications/#extract-features-from-an-arbitrary-intermediate-layer-with-vgg19
  Returns: 
    function(x, preprocess=True):
      Args: 
        x: image tuple/ndarray h,w,c(RGB), domain=(0.,255.)
      Returns:
        a tuple of lists, ([content_features], [style_features])
  usage:
    (content_features, style_features) = vgg_layers19(content_layers, style_layers)(x_train)
  """
  preprocessingFn = tf.keras.applications.vgg19.preprocess_input
  base_model = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_shape=input_shape)
  base_model.trainable = False
  content_features = [base_model.get_layer(name).output for name in content_layers]
  style_features = [base_model.get_layer(name).output for name in style_layers]
  output_features = content_features + style_features

  model = Model( inputs=base_model.input, outputs=output_features, name="vgg_layers")
  model.trainable = False

  def _get_features(x, preprocess=True):
    """
    Args:
      x: expecting tensor, domain=255. hwcRGB
    """
    if preprocess and callable(preprocessingFn): 
      x = preprocessingFn(x)
    output = model(x) # call as tf.keras.Layer()
    return ( output[:len(content_layers)], output[len(content_layers):] )

  return _get_features 



class VGG_Features():
  """ get content and style features from VGG model """
  def __init__(self, loss_model, style_image=None, target_style_gram=None, batch_size=1):
    self.loss_model = loss_model
    if style_image is not None:
      assert style_image.shape == (256,256,3), "ERROR: loss_model expecting input_shape=(256,256,3), got {}".format(style_image.shape)
      # style_image must use the same domain as the training inputs, (0.,255.)
      self.style_image = style_image
      self.target_style_gram = VGG_Features.get_style_gram(self.loss_model, self.style_image, batch_size)
    if target_style_gram is not None:
      self.target_style_gram = target_style_gram

  @staticmethod
  def get_style_gram(vgg_features_model, style_image, batch_size=1):
    style_batch = tf.repeat( style_image[tf.newaxis,...], repeats=batch_size, axis=0)
    # show([style_image], w=128, domain=(0.,255.) )

    # B, H, W, C = style_batch.shape
    (_, style_features) = vgg_features_model( style_batch , preprocess=True ) # hwcRGB
    target_style_gram = [ fnstf_utils.gram(value)  for value in style_features ]  # list
    return target_style_gram  

  def __call__(self, input_batch):
    content_features, style_features = self.loss_model( input_batch, preprocess=True )
    style_gram = tuple(fnstf_utils.gram(value)  for value in style_features)  # tuple(<generator>)
    return (content_features[0],) + style_gram  # tuple = tuple + tuple




class TransformerNetwork_VGG(tf.keras.Model):
  def __init__(self, transformer, style_image, batch_size=1):
    super(TransformerNetwork_VGG, self).__init__()
    self.transformer = transformer 
    # type: tf.keras.models.Model
    # input_shapes:  (None, 256,256,3)
    # output_shapes: (None, 256,256,3)


    style_model = {
       'content_layers':['block5_conv2'],
       'style_layers': ['block1_conv1',
                  'block2_conv1',
                  'block3_conv1', 
                  'block4_conv1', 
                  'block5_conv1']
    }
    vgg_model = vgg_layers19( style_model['content_layers'], style_model['style_layers'] )

    # frozen VGG19 feature extractor, used as self.vgg in call() and in the training loop
    self.vgg = VGG_Features(vgg_model, style_image=style_image, batch_size=batch_size) 

    # input_shapes:  (None, 256,256,3)
    # output_shapes: [(None, 16, 16, 512),  (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
    #                [ content_features,    style_gram_1, style_gram_2, style_gram_3, style_gram_4, style_gram_5 ]


  def call(self, inputs):
    x = inputs                # shape=(None, 256,256,3)

    # shape=(None, 256,256,3)
    generated_image = self.transformer(x)                    

    # shape=[(None, 16, 16, 512),  (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
    vgg_outputs = self.vgg(generated_image)           

    return vgg_outputs       # tuple(content1, style1, style2, style3, style4, style5)
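fnstf_utils.gram() is not shown in the question. Judging from the shape comments above, it maps a (B, H, W, C) feature map to a (B, C, C) Gram matrix. Below is a minimal NumPy sketch of such a function; the name `gram` is only illustrative, and the normalization by C*H*W is an assumption (a common convention, worth checking against the original PyTorch project).

```python
import numpy as np

def gram(features):
    """Gram matrix of a batch of feature maps (sketch of the undefined fnstf_utils.gram).

    features: array of shape (B, H, W, C), channels-last as in tf.keras VGG19.
    Returns:  array of shape (B, C, C), divided by C*H*W (assumed normalization).
    """
    b, h, w, c = features.shape
    flat = features.reshape(b, h * w, c)            # (B, HW, C)
    g = np.einsum('bnc,bnd->bcd', flat, flat)       # (B, C, C) channel correlations
    return g / (c * h * w)

feats = np.random.rand(4, 16, 16, 64).astype(np.float32)
print(gram(feats).shape)  # (4, 64, 64)
```

The result is symmetric and its last two dimensions equal the channel count, which is why the style outputs in the comments have shapes like (None, 64, 64) and (None, 512, 512).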

Style image:

[Image: style image]

FEATURE_WEIGHTS= [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

GradientTape learning

With the tf.GradientTape() loop, I'm manually handling the multiple outputs, i.e. the tuple of 6 tensors returned by TransformerNetwork_VGG(x_train). This method learns correctly.

# model: an instance of TransformerNetwork_VGG; optimizer: a tf.keras optimizer
# get_MEAN_mse_loss(): MSE helper from the notebook (not shown)
@tf.function()
def train_step(x_train, loss_weights=None):
  weights = FEATURE_WEIGHTS if loss_weights is None else loss_weights  # [content, style1..style5]
  with tf.GradientTape() as tape:
    y_pred = model(x_train)
    generated_content_features = y_pred[:1]
    generated_style_gram = y_pred[1:]

    # targets: content features of the input batch, precomputed style grams
    y_true = model.vgg(x_train)
    target_content_features = y_true[:1]
    target_style_gram = model.vgg.target_style_gram

    content_loss = get_MEAN_mse_loss(target_content_features[0], generated_content_features[0]) * weights[0]
    style_loss = tuple(get_MEAN_mse_loss(x,y)*w for x,y,w in zip(target_style_gram, generated_style_gram, weights[1:]))
    total_loss = content_loss + tf.reduce_sum(style_loss)

  # only the transformer is trained; the VGG19 loss model stays frozen
  transformer = model.transformer
  grads = tape.gradient(total_loss, transformer.trainable_weights)
  optimizer.apply_gradients(zip(grads, transformer.trainable_weights))
  return total_loss
# GradientTape epoch=5: 
# losses:             [   6078.71         70.23  4495.13 13817.65 88217.99    48.36]
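Both training paths are meant to reduce the 6 outputs to the same scalar: FEATURE_WEIGHTS[0] times the content MSE plus FEATURE_WEIGHTS[1:] times the five style MSEs, in the output order (content, style1..style5) from the comments in TransformerNetwork_VGG. A small NumPy sketch of that reduction, assuming get_MEAN_mse_loss() is a plain element-wise mean squared error (its definition is not shown); `mean_mse` and `weighted_total_loss` are hypothetical names:

```python
import numpy as np

def mean_mse(target, generated):
    """Element-wise mean squared error (assumed behaviour of get_MEAN_mse_loss)."""
    diff = np.asarray(target, dtype=np.float64) - np.asarray(generated, dtype=np.float64)
    return float(np.mean(diff ** 2))

def weighted_total_loss(targets, preds, feature_weights):
    """targets, preds: sequences of 6 arrays ordered (content, style1..style5).

    feature_weights[0] scales the content loss, feature_weights[1:] the style losses.
    Returns the total loss and the list of weighted per-output losses.
    """
    assert len(targets) == len(preds) == len(feature_weights)
    weighted = [w * mean_mse(t, p) for t, p, w in zip(targets, preds, feature_weights)]
    return sum(weighted), weighted

# Every per-output MSE equals 1.0, so each weighted loss equals its weight.
shapes = [(16, 16, 512), (64, 64), (128, 128), (256, 256), (512, 512), (512, 512)]
targets = [np.zeros(s) for s in shapes]
preds = [np.ones(s) for s in shapes]
total, weighted = weighted_total_loss(targets, preds, [1., 2., 3., 4., 5., 6.])
print(total)  # 21.0
```

With unit losses each weighted loss equals its weight, which mirrors the check with FEATURE_WEIGHTS = [1.,2.,3.,4.,5.,6.] described in the question.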

[Image: results with GradientTape]

model.fit() learning

With tf.keras.models.Model.fit(), the multiple outputs, i.e. the tuple of 6 tensors, are fed to the loss function individually as loss(y_true, y_pred), and each loss is multiplied by its weight during reduction. This method does learn to approximate the content image, but it does not learn to minimize the style losses! I cannot figure out why.

# model: an instance of TransformerNetwork_VGG, compiled with one loss per output
history = model.fit(
  x=train_dataset.repeat(NUM_EPOCHS),
  epochs=NUM_EPOCHS,
  steps_per_epoch=NUM_BATCHES,
  callbacks=callbacks,
)
# model.fit() epoch=5: 
# losses:             [  4661.08       219.95   6959.01   4897.39 209201.16     84.68]

[Image: results with model.fit()]

After 50 epochs, with boosted style weights: FEATURE_WEIGHTS= [ 0.1854, 1605.23, 25.08, 8.16, 1.28, 2330.79] # boost style loss x100

[Image: results with model.fit() after 50 epochs]

step=50, losses=[269899.45 337.5 69617.7 38424.96 9192.36 85903.44 66423.51]
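With model.fit(), the only way the style targets can reach the loss is through y_true, since (as noted in the comments) the loss function only takes (y_true, y_pred). Each batch therefore needs a y_true tuple with the same structure as the 6 model outputs: the content features of the input batch plus the five target style grams, tiled along the batch axis. A NumPy sketch of assembling such a tuple; `make_fit_targets` is a hypothetical helper, and in a real tf.data pipeline the content features would come from running the frozen VGG on each batch:

```python
import numpy as np

def make_fit_targets(content_features, target_style_grams):
    """Build a y_true tuple matching the model outputs (content, style1..style5).

    content_features:   array (B, ...), VGG content features of the input batch
    target_style_grams: list of arrays, each (C, C) or (N, C, C), from the style image
    """
    batch_size = content_features.shape[0]
    tiled = []
    for g in target_style_grams:
        g2d = g if g.ndim == 2 else g[0]                                 # one (C, C) target
        tiled.append(np.repeat(g2d[np.newaxis], batch_size, axis=0))     # (B, C, C)
    return (content_features,) + tuple(tiled)

batch = 4
content = np.zeros((batch, 16, 16, 512), dtype=np.float32)
grams = [np.eye(c, dtype=np.float32) for c in (64, 128, 256, 512, 512)]
y_true = make_fit_targets(content, grams)
print([t.shape for t in y_true])
```

Tiling each gram to the batch size keeps every y_true[i] the same shape as the corresponding y_pred[i], so a per-output MSE compares like with like instead of relying on broadcasting.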

check mse losses * weights

I tested my model with the losses and weights fixed as follows:
  • FEATURE_WEIGHTS = SEQ = [1.,2.,3.,4.,5.,6.]
  • MSELoss(y_true, y_pred) == tf.ones() of equal shape
and confirmed that model.fit() handles the weighted multi-output losses (losses * weights) correctly.

[Image: losses as ones]

I've checked everything I can think of, but I cannot figure out how to make the model learn correctly with model.fit(). What am I missing?

The full notebook is available here: https://colab.research.google.com/github/mixuala/fast_neural_style_pytorch/blob/master/notebook/%5BSO%5D_FastStyleTransfer.ipynb

  • Can you please provide the definition for TransformerNetwork_VGG.vgg()? Commented Mar 6, 2020 at 20:25
  • But from first impression, model.fit and GradientTape are attempting to fit two different models. model.fit trains TransformerNetwork_VGG, while GradientTape trains TransformerNetwork_VGG.transformer. Commented Mar 6, 2020 at 20:28
  • I seem to have broken things when I "refactored" the code for Stack Overflow, and now nothing learns correctly. I'm still going back through earlier revisions to see what worked, but I'll add the most recent code for TransformerNetwork_VGG to the question. Commented Mar 7, 2020 at 6:34
  • But the original idea of TransformerNetwork_VGG was to set TransformerNetwork_VGG.vgg.loss_model.trainable=False to fix the VGG19 weights. Otherwise, with model.fit() I was having problems handling both the loss from the multiple outputs of VGG_Features() and the content_loss from the generated image after the transformer. The model.fit() loss function only takes (y_true, y_pred) as args. Commented Mar 7, 2020 at 6:52
  • This stuff stinks. I "cleaned up" my notebook before I posted this question, only to discover that it no longer worked, even with tf.GradientTape. I spent a week building unit tests for every piece of code. Finally, after carefully comparing loss values between an old notebook and the current one (MANUALLY), I discovered that my target_style_gram was calculated using a style_image with domain=(0,1), but my input images were fed with domain=(0,255). Doh!!! It now works. Commented Mar 11, 2020 at 4:03
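The domain mismatch found in the last comment is easy to miss because tf.keras.applications.vgg19.preprocess_input ("caffe" mode) only flips RGB to BGR and subtracts the ImageNet channel means; it does not rescale its input. A NumPy sketch of that preprocessing, plus a hypothetical guard (`looks_like_domain_255`, not part of Keras) that flags inputs that look like they are in (0,1):

```python
import numpy as np

# ImageNet channel means in BGR order, subtracted by Keras "caffe"-mode preprocessing
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def caffe_preprocess(x_rgb):
    """NumPy sketch of vgg19.preprocess_input for float RGB input of shape (..., 3)."""
    x_bgr = np.asarray(x_rgb, dtype=np.float32)[..., ::-1]  # RGB -> BGR
    return x_bgr - IMAGENET_MEAN_BGR                         # zero-center; no rescaling

def looks_like_domain_255(x):
    """Hypothetical guard: True if any pixel value exceeds 1.0."""
    return float(np.max(x)) > 1.0

rng = np.random.default_rng(0)
img_255 = rng.uniform(0.0, 255.0, size=(8, 8, 3)).astype(np.float32)
img_01 = img_255 / 255.0   # the same image in the domain (0., 1.)

# Per-channel spread (the actual image signal) is 255x smaller in the (0, 1) domain,
# so every preprocessed value sits close to -mean and the style grams change completely.
print(caffe_preprocess(img_255).std(axis=(0, 1)), caffe_preprocess(img_01).std(axis=(0, 1)))
print(looks_like_domain_255(img_255), looks_like_domain_255(img_01))
```

A check like this on both the style image and the first training batch would catch a (0,1) vs (0,255) mismatch before any training starts.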

