
I’d like to implement an infinite loop Dataset & DataLoader. Here’s what I tried:

import numpy as np
from torch.utils.data import DataLoader, Dataset

class Infinite(Dataset):
    def __len__(self):
        return HPARAMS.batch_size
        # return 1<<30  # This causes huge memory usage.
    def __getitem__(self, idx):
        """Randomly generates one new example."""
        return sample_func_to_be_parallelized()

infinite_loader = DataLoader(
    dataset=Infinite(), 
    batch_size=HPARAMS.batch_size, 
    num_workers=16,
    worker_init_fn=lambda worker_id: np.random.seed(worker_id),  
)

while True:
    for idx, data in enumerate(infinite_loader):
        # forward + backward on "data"
        pass

As you can see, the main challenge here is the __len__() method. If I put a large enough number there, like 1<<30, memory usage jumps to 10+ GB on the first iteration of the training loop, and after a while the workers are killed, presumably due to OOM.

If I put a small number there, like 1 or BATCH_SIZE, the sampled “data” in the training loop is periodically duplicated. This is not what I want, as I’d like new data to be generated and trained on at every iteration.

I’m guessing the culprit of the excessive memory usage is that, somewhere in the stack, a bunch of things are cached. From a casual look at the Python side of things I can’t pinpoint where.

Can someone advise on the best way to implement what I want? (Use DataLoader’s parallel loading, while guaranteeing that every batch loaded is entirely new.)

  • What's your sample_func_to_be_parallelized()?
    – kuzand
    Commented Feb 27, 2019 at 8:42

4 Answers


This seems to be working without periodically duplicating the data:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2

class Infinite(Dataset):

    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))


data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16)

batch_count = 0
while True:
    batch_count += 1
    print(f'Batch {batch_count}:')

    data = next(iter(data_loader))
    print(data)
    # forward + backward on "data"  

    if batch_count == 5:
        break

Result:

Batch 1:
tensor([[4, 7, 7],
        [0, 8, 0]])
Batch 2:
tensor([[6, 8, 6],
        [2, 6, 7]])
Batch 3:
tensor([[6, 6, 2],
        [8, 7, 0]])
Batch 4:
tensor([[9, 4, 8],
        [2, 4, 1]])
Batch 5:
tensor([[9, 6, 1],
        [2, 7, 5]])

So I think the problem is in your function sample_func_to_be_parallelized().


Edit: If instead of torch.randint(0, 10, (3,)) I use np.random.randint(10, size=3) in __getitem__ (as an example of the sample_func_to_be_parallelized()), then the data is indeed duplicated at each batch. See this issue.

So if you are using NumPy's RNG somewhere in your sample_func_to_be_parallelized(), then the workaround is to use

worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id) 

and to reset the seed with np.random.seed() before each call to data = next(iter(data_loader)).
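
Applied to the loader from the question, the whole workaround could look roughly like this (just a sketch: HPARAMS and sample_func_to_be_parallelized() are the OP’s, and it assumes that function draws from NumPy’s global RNG):

import numpy as np
from torch.utils.data import DataLoader

infinite_loader = DataLoader(
    dataset=Infinite(),
    batch_size=HPARAMS.batch_size,
    num_workers=16,
    # Offset each worker's NumPy seed from the main process's current RNG state.
    worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id),
)

while True:
    np.random.seed()                    # re-seed the main process from fresh entropy ...
    data = next(iter(infinite_loader))  # ... so the newly spawned workers differ every time
    # forward + backward on "data"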

  • The reason for the duplication when using NumPy is documented here: pytorch.org/docs/stable/data.html#data-loading-randomness. Basically, when the workers are initialized they copy the random state of the main process, which makes them generate the same random numbers. PyTorch, however, does reinitialize its own random state in each worker, which is why this is not an issue when you use PyTorch's randomness in the workers.
    – alvitawa
    Commented Sep 1, 2022 at 18:29
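
The pattern recommended in that docs section derives each worker's NumPy (and stdlib random) seed from the per-worker torch seed, roughly like this (a sketch, not part of this thread's original code):

import random
import numpy as np
import torch
from torch.utils.data import DataLoader

def seed_worker(worker_id):
    # PyTorch already assigns every worker a distinct torch seed;
    # reuse it so NumPy and random also differ per worker.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16, worker_init_fn=seed_worker)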

DataLoader samples your dataset without replacement. To do this, it generates a random permutation of indices between 0 and len(dataset). My guess is that this permutation is responsible for eating up most of your memory. I don't think the PyTorch APIs support infinite collections, but you could try forking the code in DataLoader and doing it yourself: use the batch_sampler param and pass in a custom variant implemented based on RandomSampler. This will let you keep the parallel loading part of DataLoader.
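
A minimal sketch of that idea (InfiniteRandomSampler is my name for it, not a PyTorch class; it keeps DataLoader's worker pool but never materializes a full permutation of indices):

import torch
from torch.utils.data import DataLoader, Sampler

class InfiniteRandomSampler(Sampler):
    """Yields random indices forever instead of a finite permutation."""
    def __init__(self, dataset):
        self.num_indices = len(dataset)

    def __iter__(self):
        while True:
            yield torch.randint(self.num_indices, (1,)).item()

# Usage (dataset, batch size and worker count are whatever your setup uses):
# loader = DataLoader(my_dataset, sampler=InfiniteRandomSampler(my_dataset),
#                     batch_size=HPARAMS.batch_size, num_workers=16)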

That being said, the iteration protocol based on __len__ and __getitem__ just isn't suited to infinite collections. You may be better off reimplementing Dataset.__len__ to just return 1, Dataset.__getitem__ to always return a new sample regardless of the index, and then sampling n times with replacement from this dataset. Technically, it will ask n times for the 0-th sample, but since you override __getitem__ to return different samples, this will effectively do what you're looking for.
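
A sketch of that second suggestion (SingleItemDataset is an illustrative name; it reuses the OP's sample_func_to_be_parallelized() and the RandomSampler class that ships with PyTorch):

from torch.utils.data import DataLoader, Dataset, RandomSampler

class SingleItemDataset(Dataset):
    def __len__(self):
        return 1                                  # the index is never actually used
    def __getitem__(self, idx):
        return sample_func_to_be_parallelized()   # always generates a fresh example

dataset = SingleItemDataset()
loader = DataLoader(
    dataset,
    batch_size=HPARAMS.batch_size,
    # Sampling "with replacement" lets num_samples exceed len(dataset); every draw asks for item 0.
    sampler=RandomSampler(dataset, replacement=True, num_samples=1_000_000),
    num_workers=16,
    # If sample_func_to_be_parallelized() uses NumPy's RNG, also add a worker_init_fn
    # like the one in the answer above, or each worker will return identical data.
)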

  • Unfortunately, I did try overriding __len__() to always return 1 and __getitem__() to always return a new sample, but it did not work: the same sample was always returned. My hypothesis is that PyTorch caches it somewhere, but I couldn't see where.
    – Covi
    Commented Jan 25, 2019 at 17:12

Try using cycle from itertools. Here is an example for a simple dataset:

Code:

from itertools import cycle

import torch
from torch.utils.data import Dataset, DataLoader


# Create some dummy data.
data = torch.tensor([[0, 0],
                     [1, 1],
                     [2, 2],
                     [3, 3]])


class DataSet(Dataset):
    """Our dataset. Iterates over tensor data"""

    def __init__(self, data):
        self.data = data
        self.n = self.data.shape[0]

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.data[idx]


bs = 1  # batch size
workers = 1  # number of workers

dataset = DataSet(data)
data_loader = DataLoader(dataset, batch_size=bs, shuffle=False, num_workers=workers)

# Infinite loop.
print(f'batch size: {bs} | number of workers: {workers}')
for i, data in cycle(enumerate(data_loader)):
    print(i, data)

Output:

batch size: 1 | number of workers: 1
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
0 tensor([[0, 0]])
1 tensor([[1, 1]])
2 tensor([[2, 2]])
3 tensor([[3, 3]])
...

batch size: 2 | number of workers: 2
0 tensor([[0, 0],
        [1, 1]])
1 tensor([[2, 2],
        [3, 3]])
0 tensor([[0, 0],
        [1, 1]])
1 tensor([[2, 2],
...

This loader iterates over a list infinitely; if the shuffle flag is set to True, the list elements are reshuffled before each new pass.

from torch.utils.data import DataLoader, Dataset, Sampler
import random

class listDataset(Dataset):
    def __init__(self):
        self.varList = [1,2,3,4]
    def __len__(self):
        return len(self.varList)
    def __getitem__(self, idx) :
        return self.varList[idx]

class customSampler(Sampler):
    """Yields indices forever; optionally reshuffles the order after each full pass."""
    def __init__(self, dataset, shuffle):
        assert len(dataset) > 0
        self.dataset = dataset
        self.shuffle = shuffle

    def __iter__(self):
        order = list(range(len(self.dataset)))
        idx = 0
        while True:
            yield order[idx]
            idx += 1
            if idx == len(order):       # completed one pass over the list
                if self.shuffle:
                    random.shuffle(order)
                idx = 0

dset = listDataset()
sampler = customSampler(dset, shuffle=True)
loader = iter(DataLoader(dataset=dset, sampler=sampler, batch_size=6, num_workers=2))
for x in range(10):
    i = next(loader)
    print(i)
  • Can you add some explanation of what the code is doing, and how it does it differently than the existing answers? It looks good, but when people read the answers it helps to know why yours is better or unique.
    – Kaia
    Commented Feb 21, 2023 at 18:43
