
I am training a model using Keras + TensorFlow with tf.distribute.MirroredStrategy on a multi-GPU setup. I would like to verify that I am using strategy.scope() correctly.

import time 
import logging
import os
import json
import datetime
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
os.environ["KERAS_BACKEND"] = "tensorflow" 
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
#os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '0'
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # Allocate GPU memory on demand instead of reserving it all up front
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
from tensorflow.python.profiler import profiler_v2 as profiler  # currently unused
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import keras
from keras import ops
from keras import layers
from keras import mixed_precision
from medicai.models import UNETRPlusPlus
from medicai.metrics import BinaryDiceMetric
from medicai.losses import BinaryDiceCELoss
from medicai.utils.inference import SlidingWindowInference
from medicai.callbacks import SlidingWindowInferenceCallback
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.experiment_config import ExperimentConfig
from src.data_pipeline.data_loader import data_loader
 



class TFCheckpointCallback(keras.callbacks.Callback):
    """Save model + optimizer + epoch using TF checkpointing and save SWI callback best score to a JSON file for restoring."""
    def __init__(self, ckpt, ckpt_manager, swi_callback, checkpoint_dir):
        super().__init__()
        self.ckpt = ckpt
        self.ckpt_manager = ckpt_manager
        self.swi_callback = swi_callback
        self.best_score_file = os.path.join(checkpoint_dir, "swi_best_score.json")


    def on_epoch_end(self, epoch, logs=None):
        # Update epoch variable and save checkpoint
        self.ckpt.epoch.assign_add(1)   # increment epoch counter
        save_path = self.ckpt_manager.save()
        print(f"Saved checkpoint: {save_path} (epoch {int(self.ckpt.epoch.numpy())})")
        
        # Save SWI best score externally
        best_score = getattr(self.swi_callback, "best_score", -float("inf"))
        with open(self.best_score_file, "w") as f:
            json.dump({"best_score": best_score}, f)
        print(f"[CheckpointCallback] Saved SWI best score: {best_score}")
        
        
class HistorySaverCallback(keras.callbacks.Callback):
    """Saves training history every epoch to CSV and allows resuming."""
    def __init__(self, history_file, initial_history=None):
        super().__init__()
        self.history_file = history_file
        self.full_history = initial_history if initial_history else {}

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            logs = {}
        for k, v in logs.items():
            self.full_history.setdefault(k, []).append(v)

        # Save updated history
        pd.DataFrame(self.full_history).to_csv(self.history_file, index=False)        





def get_model(total_device):
    
    model = UNETRPlusPlus(
        encoder_name="unetr_plusplus_encoder",
        input_shape=ExperimentConfig.input_shape,
        num_classes=ExperimentConfig.num_classes,
        classifier_activation=None,
    )
    
    
    total_train_samples = 387  # ~80% train split of the dataset, following the UNETR setup
    
    # Compute steps per epoch and total steps
    steps_per_epoch = total_train_samples // (ExperimentConfig.batch_size_train * total_device)
    print(f"Steps per epoch : {steps_per_epoch}")
    total_steps = steps_per_epoch * ExperimentConfig.epochs

    # Warmup: 10% of total steps
    warmup_steps = int(total_steps * 0.1)
    
    # CosineDecay schedule with warmup
    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=0.01 * ExperimentConfig.lr,  # start at 1% of the target LR
        decay_steps=total_steps - warmup_steps,            # decay begins after warmup
        alpha=ExperimentConfig.alpha,
        warmup_target=ExperimentConfig.lr,
        warmup_steps=warmup_steps,
    )

    model.compile(
        optimizer=keras.optimizers.AdamW(
            learning_rate=lr_schedule,
            weight_decay=ExperimentConfig.weight_decay,
        ),
        loss=BinaryDiceCELoss(
            from_logits=True,
            dice_weight=1.0,
            ce_weight=1.0,
            reduction="mean",
            num_classes=ExperimentConfig.num_classes,
        ),
        metrics=[
            BinaryDiceMetric(
                from_logits=True,
                ignore_empty=True,
                num_classes=ExperimentConfig.num_classes,
                name='dice',
            ),
            BinaryDiceMetric(
                from_logits=True,
                ignore_empty=True,
                target_class_ids=[0],
                num_classes=ExperimentConfig.num_classes,
                name='dice_tc',
            ),
            BinaryDiceMetric(
                from_logits=True,
                ignore_empty=True,
                target_class_ids=[1],
                num_classes=ExperimentConfig.num_classes,
                name='dice_wt',
            ),
            BinaryDiceMetric(
                from_logits=True,
                ignore_empty=True,
                target_class_ids=[2],
                num_classes=ExperimentConfig.num_classes,
                name='dice_et',
            )
        ],
    )

    return model

def get_inference_metric():
    swi_callback_metric = BinaryDiceMetric(
        from_logits=True,
        ignore_empty=True,
        num_classes=ExperimentConfig.num_classes,
        name='val_dice',
    )
    return swi_callback_metric



"""def run_sliding_window_inference_per_class_average(model, ds, roi_size, sw_batch_size, overlap, metrics_list):

#    Run sliding window inference on a dataset and compute all metrics (average + per class)

    for metric in metrics_list:
        metric.reset_states()
    
    swi = SlidingWindowInference(
        model,
        num_classes=metrics_list[0].num_classes,
        roi_size=roi_size,
        sw_batch_size=sw_batch_size,
        overlap=overlap
    )

    for x, y in ds:
        y_pred = swi(x)
        for metric in metrics_list:
            metric.update_state(ops.convert_to_tensor(y), ops.convert_to_tensor(y_pred))
    
    # Gather results
    results = {}
    for metric in metrics_list:
        results[metric.name] = float(ops.convert_to_numpy(metric.result()))
    
    return results"""


def main():
    # reproducibility
    keras.utils.set_random_seed(101)
    
    print(
        f"keras backend: {keras.config.backend()}\n"
        f"keras version: {keras.version()}\n"
        f"tensorflow version: {tf.__version__}\n"
    )
    
    
    # get keras backend
    keras_backend = keras.config.backend()
            
    strategy = tf.distribute.MirroredStrategy()  # one replica per visible GPU by default
    total_device = strategy.num_replicas_in_sync
    
    print('Keras backend:', keras_backend)
    print('Total devices found:', total_device)

    
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))  
    base_save_path = os.path.join(project_root, "experiments", "msd_brain")
    unetrplusplus_path = os.path.join(base_save_path, "unetrplusplus")
    os.makedirs(unetrplusplus_path, exist_ok=True)

    # Subfolders
    logs_path = os.path.join(unetrplusplus_path, "logs")
    history_path = os.path.join(unetrplusplus_path, "history")
    plots_path = os.path.join(unetrplusplus_path, "plots")
    os.makedirs(logs_path, exist_ok=True)
    os.makedirs(history_path, exist_ok=True)
    os.makedirs(plots_path, exist_ok=True)

    # Timestamp
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    # Save path for best model weights
    save_path = os.path.join(unetrplusplus_path, f"best_model_weights_{timestamp}.weights.h5")

    # CSV file holding the full training history
    history_file = os.path.join(history_path, "training_history.csv")
    
    
    # Load datasets
    tfrecord_pattern = os.path.join(project_root, "data", "msd_brain", "tfrecords", "{}_shard_*.tfrec")
    
    # Global batch size: per-replica batch scaled by the number of replicas
    train_batch = ExperimentConfig.batch_size_train * total_device
    
    train_ds = data_loader(
        tfrecord_pattern.format("training"),
        batch_size=train_batch,
        shuffle=True,
    )
    val_ds = data_loader(
        tfrecord_pattern.format("validation"),
        batch_size=ExperimentConfig.batch_size_val,
        shuffle=False,
    )
    test_ds = data_loader(
        tfrecord_pattern.format("test"),
        batch_size=ExperimentConfig.batch_size_val,
        shuffle=False,
    )
    
    with strategy.scope():
        # Variables created under the scope are mirrored across all replicas
        model = get_model(total_device)
        

    checkpoint_dir = os.path.join(unetrplusplus_path, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    with strategy.scope():
        ckpt = tf.train.Checkpoint(
            epoch=tf.Variable(0),          # epoch counter — saved as part of checkpoint
            optimizer=model.optimizer,     # optimizer state
            model=model                    # model weights
        )
    
        ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=3)
            
        
        # Validation with sliding window callback
        swi_callback_metric = get_inference_metric()
        # Create SWI callback
        swi_callback = SlidingWindowInferenceCallback(
            model,
            dataset=val_ds,
            metrics=swi_callback_metric,
            num_classes=ExperimentConfig.num_classes,
            interval=ExperimentConfig.sliding_window_interval,
            overlap=ExperimentConfig.sliding_window_overlap,
            roi_size=(
                ExperimentConfig.input_shape[0],
                ExperimentConfig.input_shape[1],
                ExperimentConfig.input_shape[2],
            ),
            sw_batch_size=ExperimentConfig.sw_batch_size * total_device,
            save_path=save_path,
        )
        
        
        # TFCheckpointCallback (save model, optimizer, epoch + SWI best score)
        tf_ckpt_callback = TFCheckpointCallback(ckpt, ckpt_manager, swi_callback, checkpoint_dir)
        
        # History callback
        # Load previous history if exists
        if os.path.exists(history_file):
            prev_history = pd.read_csv(history_file).to_dict(orient='list')
        else:
            prev_history = {}
        history_callback = HistorySaverCallback(history_file, initial_history=prev_history)
        
        
        # Resume or start from scratch
        if ckpt_manager.latest_checkpoint:
            ckpt.restore(ckpt_manager.latest_checkpoint)
            initial_epoch = int(ckpt.epoch.numpy())
            print(f"[Resume] Restored checkpoint: starting from epoch {initial_epoch}")
            
            # Restore SWI best score
            best_score_file = os.path.join(checkpoint_dir, "swi_best_score.json")
            if os.path.exists(best_score_file):
                with open(best_score_file, "r") as f:
                    swi_callback.best_score = json.load(f).get("best_score", -float("inf"))
                print(f"[Resume] Restored SWI best validation score: {swi_callback.best_score}")
            else:
                print("[Resume] No saved SWI best score found; starting fresh.")
        else:
            initial_epoch = 0
            print("[Resume] No checkpoint found. Starting from scratch.")
    



    print(f"Model size: {model.count_params() / 1e6:.2f} M")
    
    start_time = time.time()
    
    with strategy.scope():
        history = model.fit(
            train_ds,
            epochs=ExperimentConfig.epochs,
            initial_epoch=initial_epoch,
            callbacks=[
                swi_callback,
                tf_ckpt_callback,
                history_callback,
            ],
        )

    

    end_time = time.time()
    training_time = end_time - start_time
    print(f"Total training time (seconds): {training_time:.2f}")



    # Save the final training history to CSV
    full_history = history_callback.full_history
    pd.DataFrame(full_history).to_csv(history_file, index=False)

    # Plot loss
    plt.figure(figsize=(10, 5))
    plt.plot(full_history['loss'], label='train_loss')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training Loss")
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(plots_path, f"loss_curve_{timestamp}.png"))
    plt.close()


    # Plot average Dice
    if 'dice' in full_history:
        plt.figure(figsize=(10, 5))
        plt.plot(full_history['dice'], label='train_dice')
        plt.xlabel("Epoch")
        plt.ylabel("Average Dice")
        plt.title("Training Average Dice")
        plt.legend()
        plt.grid()
        plt.savefig(os.path.join(plots_path, f"dice_curve_{timestamp}.png"))
        plt.close()
        
    print("Training and saving plots finished successfully.")      


    

if __name__ == "__main__":
    main()    



    

To avoid mistakes, I currently put almost everything related to training inside strategy.scope(), including some objects where I am not sure whether they create TensorFlow variables or not.

Specifically, inside the scope I create:

  • The model

  • The optimizer

  • The loss

  • All training metrics

  • A metric used by a custom validation callback

  • Checkpoint objects (tf.train.Checkpoint, CheckpointManager)

  • Callbacks that reference the model and metrics

Datasets, paths, logging, and pure Python utilities are created outside the scope.

My current understanding is (a minimal sketch of this pattern follows the list):

  1. Objects that create TensorFlow variables (model, optimizer, metrics) must be created inside strategy.scope().

  2. Objects that own or update metrics (e.g., custom callbacks that track validation scores) should also be created inside the scope.

  3. Checkpoint objects should be created inside the scope so they correctly track distributed variables.

  4. Dataset creation does not need to be inside the scope.
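
As a sanity check on points 1-4, here is a minimal, self-contained sketch of the pattern I believe is canonical (the toy model, data, and numbers are purely illustrative; only the placement relative to strategy.scope() matters):

import numpy as np
import tensorflow as tf
import keras

strategy = tf.distribute.MirroredStrategy()
global_batch = 8 * strategy.num_replicas_in_sync  # per-replica batch x replica count

# Dataset creation stays outside the scope; Keras distributes it across replicas in fit().
x = np.random.rand(64, 16).astype("float32")
y = np.random.rand(64, 1).astype("float32")
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(global_batch)

with strategy.scope():
    # Everything that creates tf.Variables: model, optimizer, metrics.
    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile(
        optimizer=keras.optimizers.AdamW(learning_rate=1e-3),
        loss="mse",
        metrics=[keras.metrics.MeanAbsoluteError()],
    )
    # A checkpoint created here tracks the mirrored variables.
    ckpt = tf.train.Checkpoint(model=model, optimizer=model.optimizer)

model.fit(ds, epochs=1)  # fit() itself does not have to be inside the scope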

My main concern is that there are some objects where I am not 100% sure whether they create shared TensorFlow variables internally (for example, custom callbacks or utility classes that accept metrics or models).

Because of this uncertainty, I chose what seems like the safest option, which is putting everything I am unsure about inside strategy.scope().
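
For what it's worth, the quick check I know of is to instantiate the object and inspect its .variables attribute, which Keras layers, models, metrics, and optimizers all expose (a rough sketch; it will not catch objects that create raw tf.Variables lazily):

import keras

metric = keras.metrics.Mean()
print(metric.variables)        # non-empty: metric state lives in variables

layer = keras.layers.Dense(4)
layer.build((None, 8))         # weights are only created at build / first call
print([v.name for v in layer.variables])

A plain Python callback that only stores Python attributes (like best_score) creates no variables, so by this reasoning it would not need the scope.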

So my question is: Is my code correct and safe for distributed multi-GPU training, or are there any mistakes in how I am using tf.distribute.MirroredStrategy and strategy.scope()?

  • Hi @Ahmed, yes, your code is safe: keeping stateful objects like the model, optimizer, metrics, and checkpoints inside strategy.scope() is the correct and standard approach. One minor suggestion is to use .batch(batch_size, drop_remainder=True) so the last uneven batch does not crash the GPUs, and add .prefetch(tf.data.AUTOTUNE) so your training does not wait for data; a sketch follows this thread. Commented Jan 5 at 6:22
  • @Sagar Thank you so much for your help! Much appreciated! Commented Jan 6 at 9:56
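
A sketch of the suggested pipeline change, assuming batching happens at the end of a plain tf.data pipeline (the helper name here is illustrative, not part of the actual data_loader):

import tensorflow as tf

def finalize_dataset(ds: tf.data.Dataset, batch_size: int) -> tf.data.Dataset:
    # drop_remainder=True keeps every global batch evenly divisible across replicas;
    # prefetch overlaps input preprocessing with training on the GPUs.
    return ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

drop_remainder matters most under MirroredStrategy, since a final partial batch cannot always be split evenly across replicas.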
