Here is how to customize the building blocks of your dataloader.
Let's start with the default signature of the torch DataLoader:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
                                         batch_sampler=None, num_workers=0, collate_fn=None,
                                         pin_memory=False, drop_last=False, timeout=0,
                                         worker_init_fn=None, *, prefetch_factor=2,
                                         persistent_workers=False)
Let's first look at how to create your "dataset". My example follows the logic of this tutorial, adapted to your dataset structure (with an A and a B image folder):
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
import copy
import random

import cv2
from torch.utils.data import Dataset


class GANDataset(Dataset):
    def __init__(self, csv_file, sampler, transform, mode='train'):
        """
        Arguments:
            csv_file: Path to the csv file. In your case, let's say it has two
                columns, "A" and "B", and (for your training set) 8000 rows.
                For example, item 0 of column A is the full path to your first
                training input image, and item 0 of column B is the full path
                to the corresponding ground-truth image.
            sampler: some sort of sampler, either a custom one or one of the
                torch samplers. I will show what a simple sampler looks like.
            transform: transforms to be applied to a sample. I will show a
                minimal transform composite later.
            mode: 'train' or 'inference'
        """
        self.dataset = read_data_from_csv(csv_file)
        self.dataset_indices = sampler.return_indices(self.dataset, mode)
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.dataset_indices)

    def __getitem__(self, idx):
        img_A, img_B = self.read_images(idx)
        # In training mode both input and target must be readable;
        # in inference mode only the input must be.
        if self.mode == 'train':
            to_move_on = img_A is not None and img_B is not None
        else:  # 'inference'
            to_move_on = img_A is not None
        if to_move_on:
            data_sample = copy.deepcopy(self.dataset[idx])
            data_sample['img_A'] = img_A
            data_sample['img_B'] = img_B
            return self.transform(data_sample)
        # If this sample cannot be read for any reason, fall back to a random
        # index so your training process doesn't stop.
        another_idx = random.randint(0, len(self.dataset_indices) - 1)
        return self[self.dataset_indices[another_idx]]

    def read_images(self, idx):
        img_A, img_B = None, None
        try:
            if self.dataset[idx]['A'] is not None:
                img_A = cv2.imread(self.dataset[idx]['A'])
        except Exception:
            img_A = None
        try:
            if self.dataset[idx]['B'] is not None:
                img_B = cv2.imread(self.dataset[idx]['B'])
        except Exception:
            img_B = None
        return img_A, img_B
The function "read_data_from_csv" should be defined by you to build a list of dict samples from your csv file. It can be as simple as this:
import pandas as pd


def read_data_from_csv(csv_file):
    df = pd.read_csv(csv_file)
    dataset = []
    for index, row in df.iterrows():
        # pd.isnull also catches None, so one check per column is enough.
        A_file = row['A'] if not pd.isnull(row['A']) else None
        B_file = row['B'] if not pd.isnull(row['B']) else None
        # Keep the row only if at least one side exists.
        if A_file is not None or B_file is not None:
            new_data_item = {'A': A_file, 'B': B_file,
                             'img_A': None, 'img_B': None}
            dataset.append(new_data_item)
    return dataset
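For reference, a csv file for this layout could be as simple as the sketch below (hypothetical paths; an empty B cell is read as NaN, which pd.isnull catches, so inference-only rows work too):

A,B
/data/train/A/0001.png,/data/train/B/0001.png
/data/train/A/0002.png,/data/train/B/0002.png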
Let's look at a simple sampler implementing the most basic sampling strategy, which should be enough for a GAN model. Look here for more: https://github.com/pytorch/pytorch/blob/main/torch/utils/data/sampler.py
import random


class NaiveSampler(object):
    def return_indices(self, dataset, mode):
        dataset_indices = list(range(len(dataset)))
        if mode == 'train':
            random.shuffle(dataset_indices)
        return dataset_indices
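A quick sanity check of what the sampler returns, on a toy four-row dataset:

dataset = [{'A': 'a0'}, {'A': 'a1'}, {'A': 'a2'}, {'A': 'a3'}]
print(NaiveSampler().return_indices(dataset, mode='train'))
# e.g. [2, 0, 3, 1] -- shuffled in 'train' mode, [0, 1, 2, 3] in 'inference'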
Let's look at basic transforms. This example first decides whether a data augmentation should happen (in my case I want the sample to be flipped with a certain probability, one of your training hyperparameters) and then applies the essential conversions: normalization/standardization and moving from numpy array to torch tensor. For more details, look at this tutorial: https://pytorch.org/vision/stable/transforms.html
from torchvision import transforms

transform = transforms.Compose([FlipAugmentation(hflip_probability, vflip_probability),
                                SampleGenerator(mean_A, std_A, mean_B, std_B)])
My augmentation class could look like this:
import copy
from functools import reduce

import numpy as np


class FlipAugmentation(object):
    def __init__(self, hflip_probability, vflip_probability):
        augmentation_list = []
        self.sequence_augmentation = lambda x: x  # identity if nothing is enabled
        if hflip_probability > 0:
            augmentation_list.append(HorizontalFlip(hflip_probability))
        if vflip_probability > 0:
            augmentation_list.append(VerticalFlip(vflip_probability))
        if len(augmentation_list) > 0:
            # Chain all enabled augmentations into a single callable.
            def compose(g, f):
                return lambda x: g(f(x))
            self.sequence_augmentation = reduce(compose, augmentation_list,
                                                lambda x: x)

    def __call__(self, sample):
        return self.sequence_augmentation(sample)
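Note the composition order: since compose(g, f) returns g(f(x)), the reduce above applies the last augmentation in the list first. For flips this makes no difference because they commute, but it is worth knowing if you add order-sensitive augmentations. A toy check with hypothetical functions (not part of the pipeline):

from functools import reduce

double = lambda x: 2 * x
inc = lambda x: x + 1

def compose(g, f):
    return lambda x: g(f(x))

chained = reduce(compose, [double, inc], lambda x: x)
print(chained(3))  # 8: inc runs first (3 + 1), then double (4 * 2)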
class HorizontalFlip(object):
    def __init__(self, hflip_probability):
        self.hflip_probability = hflip_probability

    def __call__(self, sample):
        if np.random.rand() < self.hflip_probability:
            returned_sample = copy.deepcopy(sample)
            # Flip along the width axis (images are H x W x C).
            returned_sample['img_A'] = returned_sample['img_A'][:, ::-1, :]
            returned_sample['img_B'] = returned_sample['img_B'][:, ::-1, :]
            return returned_sample
        return sample
class VerticalFlip(object):
    def __init__(self, vflip_probability):
        self.vflip_probability = vflip_probability

    def __call__(self, sample):
        if np.random.rand() < self.vflip_probability:
            returned_sample = copy.deepcopy(sample)
            # Flip along the height axis (images are H x W x C).
            returned_sample['img_A'] = returned_sample['img_A'][::-1, :, :]
            returned_sample['img_B'] = returned_sample['img_B'][::-1, :, :]
            return returned_sample
        return sample
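A minimal smoke test for the augmentation, assuming random H x W x C arrays stand in for real images:

sample = {'img_A': np.random.rand(4, 4, 3), 'img_B': np.random.rand(4, 4, 3)}
augment = FlipAugmentation(hflip_probability=1.0, vflip_probability=0.0)
flipped = augment(sample)
# np.random.rand() is always < 1.0, so the horizontal flip always fires here
assert np.array_equal(flipped['img_A'], sample['img_A'][:, ::-1, :])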
And my sample generator is as simple as this; it standardizes the input and target images and converts them to torch tensors:
import torch


class SampleGenerator(object):
    def __init__(self, mean_A, std_A, mean_B, std_B):
        self.normalize_means_stds = [mean_A, std_A, mean_B, std_B]

    def __call__(self, sample):
        img_A, img_B = sample['img_A'], sample['img_B']
        sample_called = dict(sample)
        # Scale to [0, 1], standardize, and convert to a torch tensor.
        # astype makes a contiguous copy, so the negative-stride views
        # produced by the flips are safe to pass to torch.from_numpy.
        img_A = ((img_A.astype(np.float32) / 255.0)
                 - self.normalize_means_stds[0]) / self.normalize_means_stds[1]
        sample_called['img_A'] = torch.from_numpy(img_A)
        if img_B is not None:  # img_B is absent in inference mode
            img_B = ((img_B.astype(np.float32) / 255.0)
                     - self.normalize_means_stds[2]) / self.normalize_means_stds[3]
            sample_called['img_B'] = torch.from_numpy(img_B)
        return sample_called
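One caveat not handled above: cv2.imread returns H x W x C arrays (in BGR channel order), while torch conv layers expect C x H x W. If your model consumes these tensors directly, you may want to permute inside SampleGenerator, e.g.:

sample_called['img_A'] = torch.from_numpy(img_A).permute(2, 0, 1)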
Finally, you need a batch sampler too, which can again be as simple as this (for more details, look here: https://medium.com/@haleema.ramzan/how-to-build-a-custom-batch-sampler-in-pytorch-ce04161583ee):
from torch.utils.data.sampler import Sampler


class CustomBatchSampler(Sampler):
    def __init__(self, dataset, batch_size):
        self.batch_size = batch_size
        self.data = dataset
        index_list = self.data.dataset_indices
        number_batches = len(self)
        # Pre-build the list of batches; each batch is a list of dataset indices.
        batch_list = []
        for start in range(0, number_batches * self.batch_size, self.batch_size):
            batch = []
            for b in range(start, start + self.batch_size):
                batch.append(index_list[b % len(index_list)])
            batch_list.append(batch)
        self.batch_list = batch_list

    def __iter__(self):
        for batch in self.batch_list:
            yield batch

    def __len__(self):
        # Incomplete trailing batches are dropped.
        return len(self.data) // self.batch_size
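To see the batching and drop behavior in isolation, here is a toy stand-in for the dataset (hypothetical, just to illustrate):

class TinyData:
    dataset_indices = list(range(10))

    def __len__(self):
        return 10

toy_sampler = CustomBatchSampler(TinyData(), batch_size=4)
print(len(toy_sampler))   # 2 -- the incomplete third batch is dropped
print(list(toy_sampler))  # [[0, 1, 2, 3], [4, 5, 6, 7]]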
It has all come together now, and you can create your dataloader:
from torch.utils.data import DataLoader

csv_file = 'path/to/your/file.csv'
sampler = NaiveSampler()
transform = transforms.Compose([FlipAugmentation(0.3, 0.3),
                                SampleGenerator(0, 1.0, 0, 1.0)])
train_dataset = GANDataset(csv_file, sampler, transform, mode='train')
batch_sampler = CustomBatchSampler(train_dataset, batch_size=16)
train_dataloader = DataLoader(train_dataset, num_workers=0,
                              batch_sampler=batch_sampler,
                              pin_memory=torch.cuda.is_available())
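And a minimal sketch of consuming the loader: the default collate_fn stacks the 'img_A'/'img_B' tensors along a new batch dimension (this assumes all images share the same size) and collects the 'A'/'B' path strings into lists:

for batch in train_dataloader:
    inputs = batch['img_A']   # shape (16, H, W, C) with the code above
    targets = batch['img_B']
    # ... your GAN forward/backward pass goes here ...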