
I'm building a neural network to predict how an image will be partitioned during compression using VVC (Versatile Video Coding). The model takes a single Y-frame from a YUV420 image as input and uses a CSV file containing the ground truth block positions and sizes for training.

Input and Ground Truth

  • Input: A 1-frame YUV420 10-bit image.

  • Ground Truth: A CSV file with block positions, sizes, and additional partitioning flags.

Example input: 388016_320x480_37.yuv (image omitted)

Example ground truth: 388016_320x480_37.csv (image omitted)

Problem Description:

I implemented train.py and dataset.py, but I'm encountering an error when setting batch_size > 1 in the DataLoader. With a batch size of 1, the model works correctly, but increasing the batch size leads to runtime errors.

Code Summary:

Below is a simplified version of my custom_collate_fn and DataLoader setup:

def custom_collate_fn(batch):
    frames = [item[0] for item in batch]  # Y-frame tensors
    blocks = [item[1] for item in batch]  # Block information
    frames = torch.stack(frames, dim=0)  # Stacking frames along batch dimension
    return frames, blocks

dataloader = DataLoader(
    dataset,
    batch_size=batch_size, 
    shuffle=True,
    collate_fn=custom_collate_fn
)
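To make the collate behavior concrete, here is a minimal runnable sketch with dummy data (the shapes and block tuples are assumptions for illustration, not the real dataset format). It shows that with a batch of two items, frames are stacked into a single tensor while blocks stay a Python list with one entry per sample:

```python
import torch

def custom_collate_fn(batch):
    frames = [item[0] for item in batch]  # Y-frame tensors
    blocks = [item[1] for item in batch]  # variable-length block info
    frames = torch.stack(frames, dim=0)   # [batch, 1, H, W]
    return frames, blocks

# Dummy dataset items: (Y-frame tensor, list of (x, y, w, h) blocks).
batch = [
    (torch.zeros(1, 64, 64), [(0, 0, 32, 32), (32, 0, 32, 64)]),
    (torch.zeros(1, 64, 64), [(0, 0, 64, 64)]),
]

frames, blocks = custom_collate_fn(batch)
print(frames.shape)  # torch.Size([2, 1, 64, 64])
print(len(blocks))   # 2 -- one block list per sample
```

This is exactly the "list of lists" structure described below: `blocks[0]` is the first sample's block list, `blocks[1]` the second's.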

Observations:

  • When batch_size = 1, the blocks_batch in the training loop is a list containing a single set of block data.

  • With batch_size > 1, it becomes a list of lists, causing errors when indexing.

for i, (frame, blocks_batch) in enumerate(dataloader):
    frame = frame.to(device)  # Shape: [batch_size, 1, H, W]
    blocks = blocks_batch[0]  # Works with batch_size=1 but fails with larger sizes
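One way to avoid indexing `blocks_batch[0]` is to iterate over the batch so each sample's block list is paired with its own frame slice. This is only a sketch with stand-in tensors (the dummy shapes and block tuples are assumptions), not the actual training loop:

```python
import torch

frames = torch.zeros(2, 1, 64, 64)                   # stand-in for a real batch
blocks_batch = [[(0, 0, 32, 32)], [(0, 0, 64, 64)]]  # one block list per sample

per_sample = []
for b, blocks in enumerate(blocks_batch):
    frame_b = frames[b:b + 1]  # slice keeps the batch dim: [1, 1, H, W]
    # ... compute per-sample outputs and targets from `blocks` here ...
    per_sample.append((frame_b.shape, len(blocks)))
```

Because each iteration sees a `[1, 1, H, W]` frame and that sample's own block list, the same code path works for any batch size.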

My Assumption:

The issue seems to arise from how blocks_batch is handled when batch_size > 1. The nested list structure makes it difficult to process each sample's block data within a single batch.

Questions:

  • How can I adjust the custom_collate_fn or the training loop to handle batch sizes greater than 1 effectively?

  • If there's a better approach to batch-wise handling of variable-length block data, I'd appreciate any advice.

Traceback:

  File "C:\Users\Administrator\Documents\VVC_fast\test4\train.py", line 91, in <module>
    loss1 = criterion(out_split, target_split)
  File "C:\ProgramData\miniconda3\envs\newenv3\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ProgramData\miniconda3\envs\newenv3\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\ProgramData\miniconda3\envs\newenv3\lib\site-packages\torch\nn\modules\loss.py", line 725, in forward
    return F.binary_cross_entropy_with_logits(input, target,
  File "C:\ProgramData\miniconda3\envs\newenv3\lib\site-packages\torch\nn\functional.py", line 3193, in binary_cross_entropy_with_logits
    raise ValueError(f"Target size ({target.size()}) must be the same as input size ({input.size()})")
ValueError: Target size (torch.Size([1, 1])) must be the same as input size (torch.Size([2, 1]))
  • Comment: You can either reimplement collate_fn, or implement a wrapper around your loss function that converts between the label format returned by the existing collate function and the form expected by the loss function you use. (Commented Mar 28, 2025 at 15:07)
