I have created ViT as follows:
@keras.utils.register_keras_serializable(package="ViT", name="ViT")
class ViT(keras.Model):
"""
Vision Transformer (ViT), based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
"""
def __init__(
self,
in_channels: int,
img_size: Union[Sequence[int], int],
patch_size: Union[Sequence[int], int],
hidden_size: int = 768,
mlp_dim: int = 3072,
num_layers: int = 12,
num_heads: int = 12,
proj_type: str = "conv",
pos_embed_type: str = "learnable",
classification: bool = False,
num_classes: int = 2,
dropout_rate: float = 0.0,
spatial_dims: int = 3,
post_activation: Optional[str] = "Tanh",
qkv_bias: bool = False,
save_attn: bool = False,
**kwargs,
) -> None:
"""
Args:
in_channels (int): dimension of input channels.
img_size (Union[Sequence[int], int]): dimension of input image.
patch_size (Union[Sequence[int], int]): dimension of patch size.
hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
num_layers (int, optional): number of transformer blocks. Defaults to 12.
num_heads (int, optional): number of attention heads. Defaults to 12.
proj_type (str, optional): patch embedding layer type. Defaults to "conv".
pos_embed_type (str, optional): position embedding type. Defaults to "learnable".
classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
num_classes (int, optional): number of classes if classification is used. Defaults to 2.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
post_activation (str, optional): add a final acivation function to the classification head
when `classification` is True. Default to "Tanh" for `layers.Activation("tanh")`.
qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.
"""
super().__init__(**kwargs)
self.in_channels = in_channels
self.img_size = img_size
self.patch_size = patch_size
self.hidden_size = hidden_size
self.mlp_dim = mlp_dim
self.num_layers = num_layers
self.num_heads = num_heads
self.proj_type = proj_type
self.pos_embed_type = pos_embed_type
self.classification = classification
self.num_classes = num_classes
self.dropout_rate = dropout_rate
self.spatial_dims = spatial_dims
self.post_activation = post_activation
self.qkv_bias = qkv_bias
self.save_attn = save_attn
if not (0 <= self.dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")
if self.hidden_size % self.num_heads != 0:
raise ValueError("hidden_size must be divisible by num_heads.")
self.patch_embedding = PatchEmbeddingBlock(
in_channels=self.in_channels,
img_size=self.img_size,
patch_size=self.patch_size,
hidden_size=self.hidden_size,
num_heads=self.num_heads,
proj_type=self.proj_type,
pos_embed_type=self.pos_embed_type,
dropout_rate=self.dropout_rate,
spatial_dims=self.spatial_dims,
)
self.blocks = keras.Sequential(
[
TransformerBlock(
hidden_size=self.hidden_size,
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
dropout_rate=self.dropout_rate,
qkv_bias=self.qkv_bias,
save_attn=self.save_attn,
)
for _ in range(self.num_layers)
],
name="transformer_blocks"
)
self.norm = layers.LayerNormalization(epsilon=1e-6)
if self.classification:
self.classification_dense = layers.Dense(self.num_classes, name="classification_dense")
if self.post_activation == "Tanh":
self.classification_activation = layers.Activation("tanh", name="classification_activation")
def build(self, input_shape):
if self.classification:
self.cls_token = self.add_weight(
name="cls_token",
shape=(1, 1, self.hidden_size),
initializer="zeros",
trainable=True,
)
super().build(input_shape)
def call(self, x: tf.Tensor, training: bool = None):
x = self.patch_embedding(x, training=training)
if self.classification:
# Expand CLS token to batch size and prepend
batch_size = tf.shape(x)[0]
cls_token = tf.broadcast_to(self.cls_token, [batch_size, 1, self.hidden_size])
x = tf.concat([cls_token, x], axis=1)
hidden_states_out = []
for layer in self.blocks.layers: # Sequential.layers gives the list of layers
x = layer(x, training=training)
hidden_states_out.append(x)
x = self.norm(x, training = training)
if self.classification:
x = self.classification_dense(x[:, 0]) # use CLS token for classification
if self.post_activation == "Tanh":
x = self.classification_activation(x)
#return x, hidden_states_out
return x
def get_config(self):
config = super().get_config()
# Update the config with the custom layer's parameters
config.update(
{
"in_channels": self.in_channels,
"img_size": self.img_size,
"patch_size": self.patch_size,
"hidden_size": self.hidden_size,
"mlp_dim": self.mlp_dim,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"proj_type": self.proj_type,
"pos_embed_type": self.pos_embed_type,
"classification": self.classification,
"num_classes": self.num_classes,
"dropout_rate": self.dropout_rate,
"spatial_dims": self.spatial_dims,
"post_activation": self.post_activation,
"qkv_bias": self.qkv_bias,
"save_attn": self.save_attn,
}
)
return config
Every block is tested for serialization and deserialization successfully.
Now if I test ViT with
classification: bool = False
# -------------------------
# Instantiate the ViT block
# -------------------------
inputs = Input(shape=(16, 16, 16, 1)) # e.g., 3D volume with 1 channel
vit_block = ViT(
in_channels=1,
img_size=16,
patch_size=4,
hidden_size=32,
mlp_dim=64,
num_layers=2,
num_heads=4,
dropout_rate=0.1
)
outputs = vit_block(inputs)
model = Model(inputs, outputs)
# -------------------------
# Run a forward pass
# -------------------------
x_test = np.random.rand(2, 16, 16, 16, 1).astype(np.float32)
y_original = model.predict(x_test)
# -------------------------
# Save & load
# -------------------------
model.save("vit_block_no_cls.keras")
loaded_model = load_model("vit_block_no_cls.keras")
# -------------------------
# Re-run after loading
# -------------------------
y_loaded = loaded_model.predict(x_test)
# -------------------------
# Check numerical equivalence
# -------------------------
assert np.allclose(y_original, y_loaded, atol=1e-6), \
" Outputs differ after serialization/deserialization"
print(" ViT without classification head: serialization/deserialization works correctly.")
I get
[58]
5s
# -------------------------
# Instantiate the ViT block
# -------------------------
inputs = Input(shape=(16, 16, 16, 1)) # e.g., 3D volume with 1 channel
vit_block = ViT(
in_channels=1,
img_size=16,
patch_size=4,
hidden_size=32,
mlp_dim=64,
…print(" ViT without classification head: serialization/deserialization works correctly.")
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 684ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 684ms/step
ViT without classification head: serialization/deserialization works correctly.
But i set
classification=True
# =======================
# 3D ViT test
# =======================
model = ViT(
in_channels=3,
img_size=(16, 128, 128), # (D, H, W)
patch_size=4, # patch size along each dimension
hidden_size=64,
mlp_dim=128,
num_layers=4,
num_heads=4,
classification=True,
num_classes=2,
dropout_rate=0.0,
spatial_dims=3, # 3D mode
)
# Build model by calling it once with 5D input
dummy_input = np.random.rand(2, 16, 128, 128, 3).astype(np.float32) # (B, D, H, W, C)
_ = model(dummy_input, training=False)
# Compile and fit the model for one epoch to ensure all layers are built
model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True))
dummy_labels = np.random.randint(0, 2, size=(2,)).astype(np.float32)
model.fit(dummy_input, dummy_labels, epochs=1)
# Save and reload
model.save("vit_test_3d.keras")
loaded_model = keras.saving.load_model("vit_test_3d.keras")
# Test equivalence
x = np.random.rand(2, 16, 128, 128, 3).astype(np.float32)
y1, h1 = model.predict(x)
y2, h2 = loaded_model.predict(x)
assert np.allclose(y1, y2, atol=1e-6)
for a, b in zip(h1, h2):
assert np.allclose(a, b, atol=1e-6)
print(" 3D Serialization / deserialization test PASSED")
I get
1/1 ━━━━━━━━━━━━━━━━━━━━ 22s 22s/step - loss: 0.7706
/usr/local/lib/python3.12/dist-packages/keras/src/saving/saving_lib.py:797: UserWarning: Skipping variable loading for optimizer 'adam', because it has 4 variables whereas the saved optimizer has 106 variables.
saveable.load_own_variables(weights_store.get(inner_path))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipython-input-4224097669.py in <cell line: 0>()
30 # Save and reload
31 model.save("vit_test_3d.keras")
---> 32 loaded_model = keras.saving.load_model("vit_test_3d.keras")
33
34
3 frames
/usr/local/lib/python3.12/dist-packages/keras/src/saving/saving_lib.py in _raise_loading_failure(error_msgs, warn_only)
643 warnings.warn(msg)
644 else:
--> 645 raise ValueError(msg)
646
647
ValueError: A total of 24 objects could not be loaded. Example error message for object <Dense name=dense_332, built=False>:
Layer 'dense_332' was never built and thus it doesn't have any variables. However the weights file lists 2 variables for this layer.
In most cases, this error indicates that either:
1. The layer is owned by a parent layer that implements a `build()` method, but calling the parent's `build()` method did NOT create the state of the child layer 'dense_332'. A `build()` method must create ALL state for the layer, including the state of any children layers.
2. You need to implement the `def build_from_config(self, config)` method on layer 'dense_332', to specify how to rebuild it during loading. In this case, you might also want to implement the method that generates the build config at saving time, `def get_build_config(self)`. The method `build_from_config()` is meant to create the state of the layer (i.e. its variables) upon deserialization.
List of objects that could not be loaded:
[<Dense name=dense_332, built=False>, <Dense name=dense_331, built=False>, <Dense name=dense_333, built=False>, <Dense name=dense_334, built=False>, <LayerNormalization name=layer_normalization_184, built=False>, <LayerNormalization name=layer_normalization_185, built=False>, <Dense name=dense_336, built=False>, <Dense name=dense_335, built=False>, <Dense name=dense_337, built=False>, <Dense name=dense_338, built=False>, <LayerNormalization name=layer_normalization_186, built=False>, <LayerNormalization name=layer_normalization_187, built=False>, <Dense name=dense_340, built=False>, <Dense name=dense_339, built=False>, <Dense name=dense_341, built=False>, <Dense name=dense_342, built=False>, <LayerNormalization name=layer_normalization_188, built=False>, <LayerNormalization name=layer_normalization_189, built=False>, <Dense name=dense_344, built=False>, <Dense name=dense_343, built=False>, <Dense name=dense_345, built=False>, <Dense name=dense_346, built=False>, <LayerNormalization name=layer_normalization_190, built=False>, <LayerNormalization name=layer_normalization_191, built=False>]
What I have tried:
Moving classification layer creation from
build()to__init__()Explicitly building nested layers in
build()
Why does serialization fail specifically when conditional layers are created?