I am training a model using Keras + TensorFlow with tf.distribute.MirroredStrategy on a multi-GPU setup. I would like to verify that I am using strategy.scope() correctly.
import time
import logging
import os
import json
import datetime

# NOTE: these environment variables must be set BEFORE TensorFlow is imported,
# otherwise TF will not pick them up.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"  # async CUDA allocator (reduces fragmentation)
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'            # deterministic cuDNN kernels for reproducibility
os.environ["KERAS_BACKEND"] = "tensorflow"            # force Keras 3 to use the TensorFlow backend
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
#os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '0'

import tensorflow as tf

# Enable memory growth on every visible GPU so TF allocates VRAM on demand
# instead of reserving all of it up front.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

from tensorflow.python.profiler import profiler_v2 as profiler
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import keras
from keras import ops
from keras import layers
from keras import mixed_precision
from medicai.models import UNETRPlusPlus
from medicai.metrics import BinaryDiceMetric
from medicai.losses import BinaryDiceCELoss
from medicai.utils.inference import SlidingWindowInference
from medicai.callbacks import SlidingWindowInferenceCallback
import sys

# Make the project root importable so the `src.*` modules below resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.experiment_config import ExperimentConfig
from src.data_pipeline.data_loader import data_loader
class TFCheckpointCallback(keras.callbacks.Callback):
    """Persist the full training state at the end of every epoch.

    Snapshots model weights, optimizer state and the epoch counter through a
    ``tf.train.CheckpointManager``, and mirrors the sliding-window-inference
    callback's best validation score into a JSON side file so a resumed run
    can restore it (the score is not part of the TF checkpoint itself).
    """

    def __init__(self, ckpt, ckpt_manager, swi_callback, checkpoint_dir):
        super().__init__()
        self.ckpt = ckpt
        self.ckpt_manager = ckpt_manager
        self.swi_callback = swi_callback
        # JSON file holding the best SWI validation score seen so far.
        self.best_score_file = os.path.join(checkpoint_dir, "swi_best_score.json")

    def on_epoch_end(self, epoch, logs=None):
        # Bump the persisted epoch counter, then write the checkpoint.
        self.ckpt.epoch.assign_add(1)
        save_path = self.ckpt_manager.save()
        print(f"Saved checkpoint: {save_path} (epoch {int(self.ckpt.epoch.numpy())})")
        # Record the SWI best score externally; default to -inf when the
        # callback has not produced a score yet.
        best_score = getattr(self.swi_callback, "best_score", -float("inf"))
        with open(self.best_score_file, "w") as f:
            json.dump({"best_score": best_score}, f)
        print(f"[CheckpointCallback] Saved SWI best score: {best_score}")
class HistorySaverCallback(keras.callbacks.Callback):
    """Accumulate per-epoch logs and persist them to CSV after every epoch.

    Supports resuming: pass the previously saved history (dict of lists, e.g.
    from ``pd.read_csv(...).to_dict(orient='list')``) as ``initial_history``
    and new epochs are appended to it.

    Fixes over the naive version:
      * the incoming ``initial_history`` dict and its lists are copied, so the
        caller's data is never mutated in place;
      * columns are padded with ``None`` so metrics that only appear in some
        epochs' ``logs`` (e.g. a sliding-window validation score computed
        every N epochs) do not make ``pd.DataFrame`` raise on ragged lengths.
    """

    def __init__(self, history_file, initial_history=None):
        super().__init__()
        self.history_file = history_file
        # Defensive copy: never alias the caller's dict or its lists.
        self.full_history = (
            {k: list(v) for k, v in initial_history.items()} if initial_history else {}
        )

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Number of epoch rows recorded so far (before this epoch's values).
        n_prev = max((len(v) for v in self.full_history.values()), default=0)
        for k, v in logs.items():
            col = self.full_history.setdefault(k, [])
            # Back-fill epochs recorded before this metric first appeared,
            # so the new value lands on the current epoch's row.
            col.extend([None] * (n_prev - len(col)))
            col.append(v)
        # Metrics skipped this epoch get None so all columns stay equal length.
        n_rows = max((len(v) for v in self.full_history.values()), default=0)
        for col in self.full_history.values():
            col.extend([None] * (n_rows - len(col)))
        # Save updated history.
        pd.DataFrame(self.full_history).to_csv(self.history_file, index=False)
def get_model(total_device, total_train_samples=387):
    """Build and compile the UNETR++ model with a warmup + cosine-decay LR.

    Must be called inside ``strategy.scope()`` so the model, optimizer and
    metric variables it creates are distributed.

    Args:
        total_device: Number of replicas in sync; scales the global batch size
            used to derive steps per epoch for the LR schedule.
        total_train_samples: Number of training samples used to size the LR
            schedule. Defaults to 387 (~80% split of the dataset, as in the
            original UNETR setup); previously this was hard-coded.

    Returns:
        A compiled Keras model producing raw logits (``from_logits=True`` in
        the loss and all metrics).
    """
    model = UNETRPlusPlus(
        encoder_name="unetr_plusplus_encoder",
        input_shape=ExperimentConfig.input_shape,
        num_classes=ExperimentConfig.num_classes,
        classifier_activation=None,  # raw logits; loss/metrics handle activation
    )
    # Steps per epoch from the *global* batch size (per-replica batch x replicas).
    steps_per_epoch = total_train_samples // (ExperimentConfig.batch_size_train * total_device)
    print(f"Steps per epoch : {steps_per_epoch}")
    total_steps = steps_per_epoch * ExperimentConfig.epochs
    # Warmup: 10% of total steps.
    warmup_steps = int(total_steps * 0.1)
    # CosineDecay schedule with linear warmup to ExperimentConfig.lr.
    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=0.01 * ExperimentConfig.lr,  # very small starting LR
        decay_steps=total_steps - warmup_steps,            # decay only after warmup
        alpha=ExperimentConfig.alpha,
        warmup_target=ExperimentConfig.lr,
        warmup_steps=warmup_steps,
    )
    # One averaged Dice metric plus one Dice metric per target class
    # (names suggest tumor-core / whole-tumor / enhancing-tumor).
    model.compile(
        optimizer=keras.optimizers.AdamW(
            learning_rate=lr_schedule,
            weight_decay=ExperimentConfig.weight_decay,
        ),
        loss=BinaryDiceCELoss(
            from_logits=True,
            dice_weight=1.0,
            ce_weight=1.0,
            reduction="mean",
            num_classes=ExperimentConfig.num_classes,
        ),
        metrics=[
            BinaryDiceMetric(
                from_logits=True,
                ignore_empty=True,
                num_classes=ExperimentConfig.num_classes,
                name='dice',
            ),
            BinaryDiceMetric(
                from_logits=True,
                ignore_empty=True,
                target_class_ids=[0],
                num_classes=ExperimentConfig.num_classes,
                name='dice_tc',
            ),
            BinaryDiceMetric(
                from_logits=True,
                ignore_empty=True,
                target_class_ids=[1],
                num_classes=ExperimentConfig.num_classes,
                name='dice_wt',
            ),
            BinaryDiceMetric(
                from_logits=True,
                ignore_empty=True,
                target_class_ids=[2],
                num_classes=ExperimentConfig.num_classes,
                name='dice_et',
            ),
        ],
    )
    return model
def get_inference_metric():
    """Return the Dice metric used by the sliding-window validation callback.

    Built separately from the training metrics so the callback owns its own
    metric state under the distribution scope.
    """
    return BinaryDiceMetric(
        from_logits=True,
        ignore_empty=True,
        num_classes=ExperimentConfig.num_classes,
        name='val_dice',
    )
# NOTE(review): dead code intentionally kept as a no-op string literal — a
# manual sliding-window evaluation loop superseded by
# SlidingWindowInferenceCallback. Prefer deleting it (version control keeps
# the history) over shipping it as a module-level string.
"""def run_sliding_window_inference_per_class_average(model, ds, roi_size, sw_batch_size, overlap, metrics_list):
    # Run sliding window inference on a dataset and compute all metrics (average + per class)
    for metric in metrics_list:
        metric.reset_states()
    swi = SlidingWindowInference(
        model,
        num_classes=metrics_list[0].num_classes,
        roi_size=roi_size,
        sw_batch_size=sw_batch_size,
        overlap=overlap
    )
    for x, y in ds:
        y_pred = swi(x)
        for metric in metrics_list:
            metric.update_state(ops.convert_to_tensor(y), ops.convert_to_tensor(y_pred))
    # Gather results
    results = {}
    for metric in metrics_list:
        results[metric.name] = float(ops.convert_to_numpy(metric.result()))
    return results"""
def main():
    """End-to-end distributed training entry point.

    Builds the model under tf.distribute.MirroredStrategy, resumes from the
    latest checkpoint when one exists (model + optimizer + epoch counter +
    SWI best score), trains with sliding-window validation, and writes the
    history CSV plus loss/Dice plots.
    """
    # reproducibility
    keras.utils.set_random_seed(101)
    print(
        f"keras backend: {keras.config.backend()}\n"
        f"keras version: {keras.version()}\n"
        f"tensorflow version: {tf.__version__}\n"
    )
    # get keras backend
    keras_backend = keras.config.backend()
    # One replica per visible GPU; variables created under strategy.scope()
    # below are mirrored across replicas.
    strategy = tf.distribute.MirroredStrategy()
    total_device = strategy.num_replicas_in_sync
    print('Keras backend ', keras_backend)
    print('Total device found ', total_device)
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    base_save_path = os.path.join(project_root, "experiments", "msd_brain")
    unetrplusplus_path = os.path.join(base_save_path, "unetrplusplus")
    os.makedirs(unetrplusplus_path, exist_ok=True)
    # Subfolders
    logs_path = os.path.join(unetrplusplus_path, "logs")
    history_path = os.path.join(unetrplusplus_path, "history")
    plots_path = os.path.join(unetrplusplus_path, "plots")
    os.makedirs(logs_path, exist_ok=True)
    os.makedirs(history_path, exist_ok=True)
    os.makedirs(plots_path, exist_ok=True)
    # Timestamp used to make weight/plot filenames unique per run
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    # Save path for best model weights
    save_path = os.path.join(unetrplusplus_path, f"best_model_weights_{timestamp}.weights.h5")
    # History CSV has a fixed (un-timestamped) name so resumed runs append to it
    history_file = os.path.join(history_path, f"training_history.csv")
    # Load datasets
    tfrecord_pattern = os.path.join(project_root, "data", "msd_brain", "tfrecords", "{}_shard_*.tfrec")
    # Global training batch size: per-replica batch size x replica count
    train_batch = ExperimentConfig.batch_size_train * total_device
    # Dataset creation does not need to be inside strategy.scope().
    train_ds = data_loader(
        tfrecord_pattern.format("training"),
        batch_size=train_batch,
        shuffle=True
    )
    val_ds = data_loader(
        tfrecord_pattern.format("validation"),
        batch_size=ExperimentConfig.batch_size_val,
        shuffle=False
    )
    test_ds = data_loader(  # NOTE(review): loaded but never used in this script
        tfrecord_pattern.format("test"),
        batch_size=ExperimentConfig.batch_size_val,
        shuffle=False
    )
    # Model, optimizer and metric variables must be created under the scope
    # so they are mirrored across replicas.
    with strategy.scope():
        model = get_model(total_device)
    checkpoint_dir = os.path.join(unetrplusplus_path, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Checkpoint objects and the callback-owned metric are created inside the
    # scope so they reference/track the distributed variables.
    with strategy.scope():
        ckpt = tf.train.Checkpoint(
            epoch=tf.Variable(0),       # epoch counter — saved as part of checkpoint
            optimizer=model.optimizer,  # optimizer state
            model=model                 # model weights
        )
        ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=3)
        # Validation with sliding window callback
        swi_callback_metric = get_inference_metric()
        # Create SWI callback (evaluates every `interval` epochs and saves the
        # best weights to `save_path`)
        swi_callback = SlidingWindowInferenceCallback(
            model,
            dataset=val_ds,
            metrics=swi_callback_metric,
            num_classes=ExperimentConfig.num_classes,
            interval=ExperimentConfig.sliding_window_interval,
            overlap=ExperimentConfig.sliding_window_overlap,
            roi_size=(ExperimentConfig.input_shape[0], ExperimentConfig.input_shape[1], ExperimentConfig.input_shape[2]),
            sw_batch_size=ExperimentConfig.sw_batch_size * total_device,
            save_path=save_path
        )
        # TFCheckpointCallback (save model, optimizer, epoch + SWI best score)
        tf_ckpt_callback = TFCheckpointCallback(ckpt, ckpt_manager, swi_callback, checkpoint_dir)
    # History callback: pure-Python bookkeeping, safe outside the scope.
    # Load previous history if it exists so a resumed run keeps appending.
    if os.path.exists(history_file):
        prev_history = pd.read_csv(history_file).to_dict(orient='list')
    else:
        prev_history = {}
    history_callback = HistorySaverCallback(history_file, initial_history=prev_history)
    # Resume or start from scratch
    if ckpt_manager.latest_checkpoint:
        # Restores model weights, optimizer slots and the epoch counter.
        ckpt.restore(ckpt_manager.latest_checkpoint)
        initial_epoch = int(ckpt.epoch.numpy())
        print(f"[Resume] Restored checkpoint: starting from epoch {initial_epoch}")
        # Restore SWI best score from the JSON side file (not in the TF checkpoint)
        best_score_file = os.path.join(checkpoint_dir, "swi_best_score.json")
        if os.path.exists(best_score_file):
            with open(best_score_file, "r") as f:
                swi_callback.best_score = json.load(f).get("best_score", -float("inf"))
            print(f"[Resume] Restored SWI best validation score: {swi_callback.best_score}")
        else:
            print(f"[Resume] Couldn't Restore SWI best validation score")
    else:
        initial_epoch = 0
        print("[Resume] No checkpoint found. Starting from scratch.")
    print(f"Model size: {model.count_params() / 1e6:.2f} M")
    start_time = time.time()
    # NOTE(review): model.fit does not need to run inside strategy.scope() —
    # only variable creation requires the scope — but wrapping it is harmless.
    with strategy.scope():
        history = model.fit(
            train_ds,
            epochs=ExperimentConfig.epochs,
            initial_epoch=initial_epoch,
            callbacks=[
                swi_callback,
                tf_ckpt_callback,
                history_callback
            ])
    end_time = time.time()
    training_time = end_time - start_time
    print(f"Total training time (seconds): {training_time:.2f}")
    # Save history to CSV (full history including any resumed epochs)
    full_history = history_callback.full_history
    # Save CSV
    pd.DataFrame(full_history).to_csv(history_file, index=False)
    # Plot loss
    plt.figure(figsize=(10, 5))
    plt.plot(full_history['loss'], label='train_loss')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training Loss")
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(plots_path, f"loss_curve_{timestamp}.png"))
    plt.close()
    # Plot average Dice
    if 'dice' in full_history:
        plt.figure(figsize=(10, 5))
        plt.plot(full_history['dice'], label='train_dice')
        plt.xlabel("Epoch")
        plt.ylabel("Average Dice")
        plt.title("Training Average Dice")
        plt.legend()
        plt.grid()
        plt.savefig(os.path.join(plots_path, f"dice_curve_{timestamp}.png"))
        plt.close()
    print("Training and saving plots finished successfully.")


if __name__ == "__main__":
    main()
To avoid mistakes, I currently put almost everything related to training inside strategy.scope(), including some objects where I am not sure whether they create TensorFlow variables or not.
Specifically, inside the scope I create:
The model
The optimizer
The loss
All training metrics
A metric used by a custom validation callback
Checkpoint objects (tf.train.Checkpoint, CheckpointManager)
Callbacks that reference the model and metrics
Datasets, paths, logging, and pure Python utilities are created outside the scope.
My current understanding is:
Objects that create TensorFlow variables (model, optimizer, metrics) must be created inside strategy.scope().
Objects that own or update metrics (e.g., custom callbacks that track validation scores) should also be created inside the scope.
Checkpoint objects should be created inside the scope so they correctly track distributed variables.
Dataset creation does not need to be inside the scope.
So my main concern is that there are some objects where I am not 100% sure whether they create shared TensorFlow variables internally (for example, custom callbacks or utility classes that accept metrics or models).
Because of this uncertainty, I chose what seems like the safest option, which is putting everything I am unsure about inside strategy.scope().
So my question is: Is my code correct and safe for distributed multi-GPU training, or are there any mistakes in how I am using tf.distribute.MirroredStrategy and strategy.scope()?
Your use of strategy.scope() is the correct and standard approach. One minor suggestion is to use .batch(batch_size, drop_remainder=True) so the last uneven batch does not crash the GPUs, and add .prefetch(tf.data.AUTOTUNE) so your training does not wait for data.