I wrote a custom TensorFlow model, shown below. My issue is that the backbone does not appear in model.summary() or in model.trainable_weights.

However, it does work if I access model.backbone.trainable_weights directly. This means I have to pass both model.trainable_weights and model.backbone.trainable_weights to the optimizer, but that seems wrong, and I suspect there is a deeper issue, such as the graph being broken somewhere and TensorFlow not properly registering my entire model.

How can I make the backbone layers appear in model.summary() and model.trainable_weights?

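from typing import Any, Optional, Tuple, Union

import tensorflow as tf
from transformers import TFAutoModel

# Pooling2D, ProjectionHead, and ArcLayer are custom layers (definitions not shown).
class Model(tf.keras.Model):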
    def __init__(
        self, 
        model_name: str, 
        **kwargs: Any
    ):
        super().__init__(**kwargs)
        self.backbone = TFAutoModel.from_pretrained(model_name, use_safetensors=False)
        self.backbone.trainable = True
        self.pool = Pooling2D()
        self.projection = ProjectionHead()
        self.arc = ArcLayer()

    def call(
        self, 
        inputs: Union[tf.Tensor, Tuple[tf.Tensor, Optional[tf.Tensor]]], 
        training: bool = False
    ) -> tf.Tensor:
        """
        Args:
            inputs: 
                - During Training: Tuple of (images, labels)
                - During Inference: Just images OR Tuple of (images, None)
        """
        
        if isinstance(inputs, (tuple, list)):
            images, labels = inputs[0], inputs[1]
        else:
            images, labels = inputs, None

        outputs = self.backbone(images, training=training) 
        features = self.extract_spatial_features(outputs.last_hidden_state) 
        
        pooled = self.pool(features)
        embeddings = self.projection(pooled, training=training)

        return self.arc(embeddings, labels)

1 Answer

The issue stems from lazy initialization in subclassed Keras models: they don't create any weights until they know the input shape. Since your model has not seen any data yet, the backbone weights have not been registered. You can fix this by calling model.build() with the correct input shape. This tells the model what to expect and forces it to initialize all of its weights immediately.

model = Model(model_name="google/vit-base-patch16-224")
model.build(input_shape=[(None, 224, 224, 3), (None,)])
model.summary(expand_nested=True)
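
To make the lazy-initialization mechanism concrete, here is a minimal pure-Python sketch. It is a toy stand-in and not Keras code: `LazyDense` and `ToyModel` are hypothetical names that only mimic the `build` / `trainable_weights` behavior described above, on the assumption that the weight shape depends on the input size.

```python
class LazyDense:
    """Toy stand-in for a Keras layer: weights are created on first call."""

    def __init__(self, units):
        self.units = units
        self.built = False
        self.trainable_weights = []  # empty until the input size is known

    def build(self, input_dim):
        # The kernel shape depends on the input size, so it cannot exist earlier.
        kernel = [[0.0] * self.units for _ in range(input_dim)]
        bias = [0.0] * self.units
        self.trainable_weights = [kernel, bias]
        self.built = True

    def __call__(self, x):
        if not self.built:
            self.build(len(x))
        kernel, bias = self.trainable_weights
        return [
            sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
            for j in range(self.units)
        ]


class ToyModel:
    """Toy container that collects the weights of its sub-layers."""

    def __init__(self):
        self.backbone = LazyDense(4)
        self.head = LazyDense(2)

    @property
    def trainable_weights(self):
        return self.backbone.trainable_weights + self.head.trainable_weights

    def __call__(self, x):
        return self.head(self.backbone(x))


model = ToyModel()
weights_before = len(model.trainable_weights)  # nothing has been built yet
output = model([1.0, 2.0, 3.0])                # first call triggers build()
weights_after = len(model.trainable_weights)   # kernel + bias for each layer
print(weights_before, weights_after)
```

In this toy version, the weight list is empty until the first call supplies an input size. That is the same reason the answer suggests building the real model (or calling it once on a sample batch) before inspecting its summary.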