0

I have created ViT as follows:

@keras.utils.register_keras_serializable(package="ViT", name="ViT")
class ViT(keras.Model):

    """
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    """

    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[Sequence[int], int],
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        proj_type: str = "conv",
        pos_embed_type: str = "learnable",
        classification: bool = False,
        num_classes: int = 2,
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        post_activation: Optional[str] = "Tanh",
        qkv_bias: bool = False,
        save_attn: bool = False,
        **kwargs,
    ) -> None:

        """
        Args:
                in_channels (int): dimension of input channels.
                img_size (Union[Sequence[int], int]): dimension of input image.
                patch_size (Union[Sequence[int], int]): dimension of patch size.
                hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
                mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
                num_layers (int, optional): number of transformer blocks. Defaults to 12.
                num_heads (int, optional): number of attention heads. Defaults to 12.
                proj_type (str, optional): patch embedding layer type. Defaults to "conv".
                pos_embed_type (str, optional): position embedding type. Defaults to "learnable".
                classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
                num_classes (int, optional): number of classes if classification is used. Defaults to 2.
                dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
                spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
                post_activation (str, optional): add a final acivation function to the classification head
                    when `classification` is True. Default to "Tanh" for `layers.Activation("tanh")`.
                qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
                save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.

        """

        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.img_size = img_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.mlp_dim = mlp_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.proj_type = proj_type
        self.pos_embed_type = pos_embed_type
        self.classification = classification
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.spatial_dims = spatial_dims
        self.post_activation = post_activation
        self.qkv_bias = qkv_bias
        self.save_attn = save_attn


        if not (0 <= self.dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")
        if self.hidden_size % self.num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads.")

        self.patch_embedding = PatchEmbeddingBlock(
            in_channels=self.in_channels,
            img_size=self.img_size,
            patch_size=self.patch_size,
            hidden_size=self.hidden_size,
            num_heads=self.num_heads,
            proj_type=self.proj_type,
            pos_embed_type=self.pos_embed_type,
            dropout_rate=self.dropout_rate,
            spatial_dims=self.spatial_dims,
        )

        self.blocks = keras.Sequential(

            [
            TransformerBlock(
                hidden_size=self.hidden_size,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dropout_rate=self.dropout_rate,
                qkv_bias=self.qkv_bias,
                save_attn=self.save_attn,
            )
            for _ in range(self.num_layers)
            ],
            name="transformer_blocks"
        )

        self.norm = layers.LayerNormalization(epsilon=1e-6)

        if self.classification:
            self.classification_dense = layers.Dense(self.num_classes, name="classification_dense")
            if self.post_activation == "Tanh":
                self.classification_activation = layers.Activation("tanh", name="classification_activation")
    

    def build(self, input_shape):

            if self.classification:
                self.cls_token = self.add_weight(
                    name="cls_token",
                    shape=(1, 1, self.hidden_size),
                    initializer="zeros",
                    trainable=True,
                )

            super().build(input_shape)

    def call(self, x: tf.Tensor, training: bool = None):
        x = self.patch_embedding(x, training=training)

        if self.classification:
            # Expand CLS token to batch size and prepend
            batch_size = tf.shape(x)[0]
            cls_token = tf.broadcast_to(self.cls_token, [batch_size, 1, self.hidden_size])
            x = tf.concat([cls_token, x], axis=1)

        hidden_states_out = []

        for layer in self.blocks.layers:  # Sequential.layers gives the list of layers
            x = layer(x, training=training)
            hidden_states_out.append(x)
        x = self.norm(x, training = training)

        if self.classification:
             x = self.classification_dense(x[:, 0])  # use CLS token for classification
             if self.post_activation == "Tanh":
                x = self.classification_activation(x)


        #return x, hidden_states_out
        return x



    def get_config(self):
        config = super().get_config()
        # Update the config with the custom layer's parameters
        config.update(
            {
                "in_channels": self.in_channels,
                "img_size": self.img_size,
                "patch_size": self.patch_size,
                "hidden_size": self.hidden_size,
                "mlp_dim": self.mlp_dim,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "proj_type": self.proj_type,
                "pos_embed_type": self.pos_embed_type,
                "classification": self.classification,
                "num_classes": self.num_classes,
                "dropout_rate": self.dropout_rate,
                "spatial_dims": self.spatial_dims,
                "post_activation": self.post_activation,
                "qkv_bias": self.qkv_bias,
                "save_attn": self.save_attn,


            }
        )
        return config

Every block is tested for serialization and deserialization successfully.

Now if I test ViT with

classification: bool = False
# -------------------------
# Instantiate the ViT block
# -------------------------
inputs = Input(shape=(16, 16, 16, 1))  # e.g., 3D volume with 1 channel
vit_block = ViT(
    in_channels=1,
    img_size=16,
    patch_size=4,
    hidden_size=32,
    mlp_dim=64,
    num_layers=2,
    num_heads=4,
    dropout_rate=0.1
)
outputs = vit_block(inputs)
model = Model(inputs, outputs)

# -------------------------
# Run a forward pass
# -------------------------
x_test = np.random.rand(2, 16, 16, 16, 1).astype(np.float32)
y_original = model.predict(x_test)

# -------------------------
# Save & load
# -------------------------
model.save("vit_block_no_cls.keras")
loaded_model = load_model("vit_block_no_cls.keras")

# -------------------------
# Re-run after loading
# -------------------------
y_loaded = loaded_model.predict(x_test)

# -------------------------
# Check numerical equivalence
# -------------------------
assert np.allclose(y_original, y_loaded, atol=1e-6), \
    " Outputs differ after serialization/deserialization"

print(" ViT without classification head: serialization/deserialization works correctly.")

I get

[58]
5s
# -------------------------
# Instantiate the ViT block
# -------------------------
inputs = Input(shape=(16, 16, 16, 1))  # e.g., 3D volume with 1 channel
vit_block = ViT(
    in_channels=1,
    img_size=16,
    patch_size=4,
    hidden_size=32,
    mlp_dim=64,
…print(" ViT without classification head: serialization/deserialization works correctly.")

1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 684ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 684ms/step
 ViT without classification head: serialization/deserialization works correctly.

But i set

classification=True

# =======================
# 3D ViT test
# =======================
model = ViT(
    in_channels=3,
    img_size=(16, 128, 128),  # (D, H, W)
    patch_size=4,              # patch size along each dimension
    hidden_size=64,
    mlp_dim=128,
    num_layers=4,
    num_heads=4,
    classification=True,
    num_classes=2,
    dropout_rate=0.0,
    spatial_dims=3,            # 3D mode
)

# Build model by calling it once with 5D input
dummy_input = np.random.rand(2, 16, 128, 128, 3).astype(np.float32)  # (B, D, H, W, C)
_ = model(dummy_input, training=False)



# Compile and fit the model for one epoch to ensure all layers are built
model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True))

dummy_labels = np.random.randint(0, 2, size=(2,)).astype(np.float32)
model.fit(dummy_input, dummy_labels, epochs=1)

# Save and reload
model.save("vit_test_3d.keras")
loaded_model = keras.saving.load_model("vit_test_3d.keras")


# Test equivalence
x = np.random.rand(2, 16, 128, 128, 3).astype(np.float32)
y1, h1 = model.predict(x)
y2, h2 = loaded_model.predict(x)

assert np.allclose(y1, y2, atol=1e-6)
for a, b in zip(h1, h2):
    assert np.allclose(a, b, atol=1e-6)

print(" 3D Serialization / deserialization test PASSED")

I get

1/1 ━━━━━━━━━━━━━━━━━━━━ 22s 22s/step - loss: 0.7706
/usr/local/lib/python3.12/dist-packages/keras/src/saving/saving_lib.py:797: UserWarning: Skipping variable loading for optimizer 'adam', because it has 4 variables whereas the saved optimizer has 106 variables. 
  saveable.load_own_variables(weights_store.get(inner_path))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipython-input-4224097669.py in <cell line: 0>()
     30 # Save and reload
     31 model.save("vit_test_3d.keras")
---> 32 loaded_model = keras.saving.load_model("vit_test_3d.keras")
     33 
     34 

3 frames
/usr/local/lib/python3.12/dist-packages/keras/src/saving/saving_lib.py in _raise_loading_failure(error_msgs, warn_only)
    643         warnings.warn(msg)
    644     else:
--> 645         raise ValueError(msg)
    646 
    647 

ValueError: A total of 24 objects could not be loaded. Example error message for object <Dense name=dense_332, built=False>:

Layer 'dense_332' was never built and thus it doesn't have any variables. However the weights file lists 2 variables for this layer.
In most cases, this error indicates that either:

1. The layer is owned by a parent layer that implements a `build()` method, but calling the parent's `build()` method did NOT create the state of the child layer 'dense_332'. A `build()` method must create ALL state for the layer, including the state of any children layers.

2. You need to implement the `def build_from_config(self, config)` method on layer 'dense_332', to specify how to rebuild it during loading. In this case, you might also want to implement the method that generates the build config at saving time, `def get_build_config(self)`. The method `build_from_config()` is meant to create the state of the layer (i.e. its variables) upon deserialization.

List of objects that could not be loaded:
[<Dense name=dense_332, built=False>, <Dense name=dense_331, built=False>, <Dense name=dense_333, built=False>, <Dense name=dense_334, built=False>, <LayerNormalization name=layer_normalization_184, built=False>, <LayerNormalization name=layer_normalization_185, built=False>, <Dense name=dense_336, built=False>, <Dense name=dense_335, built=False>, <Dense name=dense_337, built=False>, <Dense name=dense_338, built=False>, <LayerNormalization name=layer_normalization_186, built=False>, <LayerNormalization name=layer_normalization_187, built=False>, <Dense name=dense_340, built=False>, <Dense name=dense_339, built=False>, <Dense name=dense_341, built=False>, <Dense name=dense_342, built=False>, <LayerNormalization name=layer_normalization_188, built=False>, <LayerNormalization name=layer_normalization_189, built=False>, <Dense name=dense_344, built=False>, <Dense name=dense_343, built=False>, <Dense name=dense_345, built=False>, <Dense name=dense_346, built=False>, <LayerNormalization name=layer_normalization_190, built=False>, <LayerNormalization name=layer_normalization_191, built=False>]

What I have tried:

  1. Moving classification layer creation from build() to __init__()

  2. Explicitly building nested layers in build()

Why does serialization fail specifically when conditional layers are created?

1 Answer 1

1

Serialization fails here because of Keras load_model triggers the parent build() but doesn't automatically trigger it for its sub-layers. Please fix this by explicitly calling call .build() on every child layer to initialize there variables before loading weights.

def build(self, input_shape):
        self.patch_embedding.build(input_shape)

        blocks_shape = (input_shape[0], None, self.hidden_size)
        self.blocks.build(blocks_shape)
        self.norm.build(blocks_shape)

        if self.classification:
            self.cls_token = self.add_weight(
                name="cls_token",
                shape=(1, 1, self.hidden_size),
                initializer="zeros",
                trainable=True,
            )
            self.classification_dense.build((input_shape[0], self.hidden_size))
            
            if self.post_activation == "Tanh":
                self.classification_activation.build((input_shape[0], self.num_classes))

        super().build(input_shape)
Sign up to request clarification or add additional context in comments.

2 Comments

Thanks @Sagar . One more question please, according to Keras Layers are recursively composable, which means that if we want to assign a layer instance in another layer, then it is recommended creating such sublayers in the __init__() method and leave it to the first __call__() to trigger building their weights. So I followed this recommendation but its not working in my case, so I was wondering why it is not working here ?
That recommendation assume call() will run but load_model skip the call() method entirely, so your sub-layers never get the signal to create variables. You must build them manually because Keras don't execute your forward pass logic during loading.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.