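Below I compare two ways of iterating over a dataset indefinitely. This has been covered in other answers above, but the code below should make the difference crystal clear. Re-iterating the `DataLoader` on every pass produces a fresh shuffle each epoch. `itertools.cycle` instead saves a copy of the items from the first pass and replays that same order forever.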
import torch
from torch.utils.data import Dataset, DataLoader
import itertools
def cycle(iterable):
    """Yield items from *iterable* forever, re-iterating it on every pass.

    Unlike ``itertools.cycle``, which saves a copy of the items from the first
    pass and replays that copy, this loops over ``iterable`` again on each
    pass. For a ``DataLoader`` built with ``shuffle=True``, every pass
    therefore draws a new shuffle order.

    *iterable* should be re-iterable (a container, a DataLoader), not a
    one-shot iterator such as ``iter(dataloader)``. A one-shot iterator is
    empty on its second pass.

    If a full pass yields nothing, the generator returns. This covers an
    empty dataset or an exhausted one-shot iterator, which would otherwise
    make the loop spin forever without ever yielding.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            # Nothing to cycle over: stop instead of busy-looping forever.
            return
class CustomImageDataset(Dataset):
    """Toy map-style dataset holding the integers 1 through 6.

    Despite the name it contains no images; it only exists so that the
    DataLoader's shuffling order is easy to read in printed output.
    """

    def __init__(self):
        # Equivalent to the literal [1, 2, 3, 4, 5, 6].
        self.my_list = list(range(1, 7))

    def __len__(self):
        """Return how many samples are currently stored."""
        samples = self.my_list
        return len(samples)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        samples = self.my_list
        return samples[idx]
def print_iterations(dataiter, batchsize, num_iterations=20, dataset_size=6):
    """Draw ``num_iterations`` batches from ``dataiter`` and print each one.

    A ``----`` line is printed after the last batch of every epoch. This makes
    it easy to see whether the order changes between epochs.

    Args:
        dataiter: Iterator yielding batches (e.g. ``cycle(dataloader)``).
        batchsize: Batch size the iterator was built with; must be >= 1.
        num_iterations: Number of batches to draw. Defaults to 20.
        dataset_size: Number of samples in one epoch; must be >= 1.
            Defaults to 6, the size of ``CustomImageDataset``.

    Raises:
        ValueError: If ``batchsize`` or ``dataset_size`` is less than 1.
        StopIteration: If ``dataiter`` is exhausted before
            ``num_iterations`` batches have been drawn.
    """
    if batchsize < 1:
        raise ValueError(f"batchsize must be >= 1, got {batchsize}")
    if dataset_size < 1:
        raise ValueError(f"dataset_size must be >= 1, got {dataset_size}")
    # By default a DataLoader keeps the final partial batch, so an epoch is
    # ceil(dataset_size / batchsize) batches. A float test such as
    # `step % (6 / batchsize) == 0` is wrong whenever batchsize does not
    # divide the dataset size (e.g. 6 / 4 == 1.5, but an epoch is 2 batches).
    batches_per_epoch = -(-dataset_size // batchsize)
    for step in range(1, num_iterations + 1):
        print(f'In iteration {step} sample is {next(dataiter)}')
        if step % batches_per_epoch == 0:
            print('----')
def test(batchsize):
    """Compare the custom ``cycle`` with ``itertools.cycle`` on a shuffled loader.

    A single shuffling DataLoader is built. It is first wrapped in the custom
    ``cycle``, which re-iterates the loader each pass and so gets a new
    shuffle. It is then wrapped in ``itertools.cycle``, which replays the
    batches saved from its first pass. Twenty batches are printed from each,
    followed by a verdict line.
    """
    print(f'****** Batch size = {batchsize} **********')
    loader = DataLoader(CustomImageDataset(), batch_size=batchsize, shuffle=True)
    # The custom cycle receives the loader itself, not iter(loader). A one-shot
    # iterator could only be walked once, so there would be nothing to
    # re-iterate after the first epoch.
    strategies = (
        (cycle, '\n---> Custom cycle works fine i.e after exhaustions samples are shuffling\n\n'),
        (itertools.cycle, '\n---> itertools.cycle DOES NOT works fine i.e after exhaustions samples are NOT shuffling'),
    )
    for make_endless, verdict in strategies:
        print_iterations(make_endless(loader), batchsize)
        print(verdict)
if __name__ == "__main__":
    # Run the demo only when executed as a script, so importing this module
    # (e.g. to reuse `cycle`) does not print anything.
    for batch_size in (2, 1):
        test(batch_size)
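The output looks like the following. The exact order will differ from run to run, because the shuffle is random and no seed is set: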
****** Batch size = 2 **********
In iteration 1 sample is tensor([4, 1])
In iteration 2 sample is tensor([6, 3])
In iteration 3 sample is tensor([2, 5])
----
In iteration 4 sample is tensor([1, 3])
In iteration 5 sample is tensor([5, 4])
In iteration 6 sample is tensor([6, 2])
----
In iteration 7 sample is tensor([4, 1])
In iteration 8 sample is tensor([2, 6])
In iteration 9 sample is tensor([5, 3])
----
In iteration 10 sample is tensor([2, 1])
In iteration 11 sample is tensor([4, 3])
In iteration 12 sample is tensor([6, 5])
----
In iteration 13 sample is tensor([5, 2])
In iteration 14 sample is tensor([4, 6])
In iteration 15 sample is tensor([3, 1])
----
In iteration 16 sample is tensor([2, 1])
In iteration 17 sample is tensor([6, 5])
In iteration 18 sample is tensor([4, 3])
----
In iteration 19 sample is tensor([6, 3])
In iteration 20 sample is tensor([5, 1])
---> Custom cycle works fine i.e after exhaustions samples are shuffling
In iteration 1 sample is tensor([5, 4])
In iteration 2 sample is tensor([6, 2])
In iteration 3 sample is tensor([1, 3])
----
In iteration 4 sample is tensor([5, 4])
In iteration 5 sample is tensor([6, 2])
In iteration 6 sample is tensor([1, 3])
----
In iteration 7 sample is tensor([5, 4])
In iteration 8 sample is tensor([6, 2])
In iteration 9 sample is tensor([1, 3])
----
In iteration 10 sample is tensor([5, 4])
In iteration 11 sample is tensor([6, 2])
In iteration 12 sample is tensor([1, 3])
----
In iteration 13 sample is tensor([5, 4])
In iteration 14 sample is tensor([6, 2])
In iteration 15 sample is tensor([1, 3])
----
In iteration 16 sample is tensor([5, 4])
In iteration 17 sample is tensor([6, 2])
In iteration 18 sample is tensor([1, 3])
----
In iteration 19 sample is tensor([5, 4])
In iteration 20 sample is tensor([6, 2])
---> itertools.cycle DOES NOT works fine i.e after exhaustions samples are NOT shuffling
****** Batch size = 1 **********
In iteration 1 sample is tensor([3])
In iteration 2 sample is tensor([5])
In iteration 3 sample is tensor([4])
In iteration 4 sample is tensor([2])
In iteration 5 sample is tensor([6])
In iteration 6 sample is tensor([1])
----
In iteration 7 sample is tensor([5])
In iteration 8 sample is tensor([4])
In iteration 9 sample is tensor([3])
In iteration 10 sample is tensor([1])
In iteration 11 sample is tensor([2])
In iteration 12 sample is tensor([6])
----
In iteration 13 sample is tensor([3])
In iteration 14 sample is tensor([2])
In iteration 15 sample is tensor([1])
In iteration 16 sample is tensor([5])
In iteration 17 sample is tensor([4])
In iteration 18 sample is tensor([6])
----
In iteration 19 sample is tensor([1])
In iteration 20 sample is tensor([3])
---> Custom cycle works fine i.e after exhaustions samples are shuffling
In iteration 1 sample is tensor([3])
In iteration 2 sample is tensor([1])
In iteration 3 sample is tensor([6])
In iteration 4 sample is tensor([4])
In iteration 5 sample is tensor([5])
In iteration 6 sample is tensor([2])
----
In iteration 7 sample is tensor([3])
In iteration 8 sample is tensor([1])
In iteration 9 sample is tensor([6])
In iteration 10 sample is tensor([4])
In iteration 11 sample is tensor([5])
In iteration 12 sample is tensor([2])
----
In iteration 13 sample is tensor([3])
In iteration 14 sample is tensor([1])
In iteration 15 sample is tensor([6])
In iteration 16 sample is tensor([4])
In iteration 17 sample is tensor([5])
In iteration 18 sample is tensor([2])
----
In iteration 19 sample is tensor([3])
In iteration 20 sample is tensor([1])
---> itertools.cycle DOES NOT works fine i.e after exhaustions samples are NOT shuffling