-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
Copy pathdatapipe.py
133 lines (104 loc) · 4.34 KB
/
datapipe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# In this example, you will find data loading implementations using PyTorch
# DataPipes (https://pytorch.org/data/) across various tasks:
# (1) molecular graph data loading pipe
# (2) mesh/point cloud data loading pipe
# In particular, we make use of PyG's built-in DataPipes, e.g., for batching
# multiple PyG data objects together or for converting SMILES strings into
# molecular graph representations. We also show how to write a custom
# DataPipe (here, one that loads and parses mesh files into PyG data objects).
import argparse
import csv
import os.path as osp
import time
from itertools import chain, tee
import torch
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import (
FileLister,
FileOpener,
IterableWrapper,
)
from torch_geometric.data import Data, download_url, extract_zip
def molecule_datapipe() -> IterDataPipe:
    """Build a DataPipe of molecular graphs from the MoleculeNet HIV dataset.

    ``HIV.csv`` is downloaded into ``<directory of this file>/../data``. Each
    CSV row becomes a dictionary keyed by the CSV header. PyG's
    ``parse_smiles`` DataPipe then converts it into a PyG data object, with
    the ``HIV_active`` column as the target.

    Parsed graphs are kept in memory after the first *complete* pass. The
    returned DataPipe can therefore be iterated for several epochs, and later
    epochs skip SMILES parsing.

    Returns:
        A re-iterable ``IterDataPipe`` of PyG data objects.
    """
    # Download HIV dataset from MoleculeNet:
    url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets'
    root_dir = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
    path = download_url(f'{url}/HIV.csv', root_dir)

    def read_rows(file):
        # `FileOpener` yields `(path, stream)` pairs. `csv.DictReader` takes
        # the keys from the header row and skips blank lines. Reading the
        # rows eagerly lets us close the stream right away.
        _, stream = file
        try:
            return list(csv.DictReader(stream))
        finally:
            stream.close()

    class _InMemoryCache:
        # Replays items from memory once `source` has been fully consumed.
        # Items are yielded as-is (not copied), so downstream steps must not
        # mutate them in place.
        def __init__(self, source):
            self.source = source
            self.items = None  # Set after the first complete pass.

        def __iter__(self):
            if self.items is not None:
                yield from self.items
                return
            items = []
            for item in self.source:
                items.append(item)
                yield item
            # Only a complete pass is cached. An abandoned pass (e.g. taking
            # one example via `next(iter(...))`) restarts from `source`.
            self.items = items

    # Keep every stage a DataPipe, so each new iteration re-opens the file.
    # Wrapping single-use iterators (`itertools.chain` / `itertools.tee`)
    # instead would leave the pipe exhausted after the first epoch.
    datapipe = FileOpener([path], mode="rt")
    datapipe = datapipe.map(read_rows).unbatch()  # One dict per CSV row.
    datapipe = datapipe.parse_smiles(target_key='HIV_active')
    # `deepcopy=False`: deep-copying the whole cache on every epoch would be
    # wasteful.
    return IterableWrapper(_InMemoryCache(datapipe), deepcopy=False)
@torch.utils.data.functional_datapipe('read_mesh')
class MeshOpener(IterDataPipe):
    """A custom DataPipe that loads and parses mesh files into PyG objects.

    Registered under the functional name ``read_mesh``, so
    ``datapipe.read_mesh()`` wraps ``datapipe`` in a ``MeshOpener``.

    Args:
        dp: A DataPipe yielding paths of mesh files readable by ``meshio``.

    Yields:
        One ``Data(pos=..., face=..., category=...)`` per readable file.

        - ``pos`` holds the mesh points as ``torch.float``.
        - ``face`` is the first cell block, transposed (one column per cell)
          and stored as ``torch.long``.
        - ``category`` is the file name up to its last ``_`` (e.g.
          ``night_stand`` for ``night_stand_0001.off``).

        Files whose reading raises ``UnicodeDecodeError`` are skipped.

    Raises:
        ImportError: On construction, if ``meshio`` or ``torch_cluster`` is
            not installed.
    """

    def __init__(self, dp: IterDataPipe) -> None:
        try:
            import meshio  # noqa: F401
            import torch_cluster  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "To run this example, please install required packages:\n"
                "pip install meshio torch-cluster") from e
        super().__init__()
        self.dp = dp

    @staticmethod
    def _category(path: str) -> str:
        # ModelNet file names are `<category>_<index>.off`. Split on the
        # *last* underscore so multi-word categories such as `night_stand`
        # are kept intact; splitting on the first one would give `night`.
        return osp.basename(path).rsplit('_', 1)[0]

    def __iter__(self):
        import meshio

        for path in self.dp:
            try:
                mesh = meshio.read(path)
            except UnicodeDecodeError:
                # Failed to read the file because it is not in the expected
                # OFF format.
                continue
            pos = torch.from_numpy(mesh.points).to(torch.float)
            # PyG expects index tensors such as `face` to be `torch.long`.
            face = torch.from_numpy(mesh.cells[0].data).to(torch.long)
            face = face.t().contiguous()
            yield Data(pos=pos, face=face, category=self._category(path))
def mesh_datapipe() -> IterDataPipe:
    """Build a DataPipe of k-NN point-cloud graphs from ModelNet10 (train).

    ``ModelNet10.zip`` is downloaded into ``<directory of this file>/../data``
    and extracted into ``data/ModelNet10`` unless that folder already exists.
    Every ``*.off`` file whose parent directory is ``train`` is loaded through
    the ``read_mesh`` DataPipe.

    Loaded meshes are kept in memory after the first complete pass. The
    ``sample_points(1024)`` and ``knn_graph(k=8)`` steps come after the cache,
    so they run again on every pass.

    Returns:
        A re-iterable ``IterDataPipe`` of PyG data objects.
    """
    # Download ModelNet10 dataset from Princeton:
    url = 'http://vision.princeton.edu/projects/2014/3DShapeNets'
    root_dir = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
    path = download_url(f'{url}/ModelNet10.zip', root_dir)
    root_dir = osp.join(root_dir, 'ModelNet10')
    if not osp.exists(root_dir):
        extract_zip(path, root_dir)

    def is_train(path: str) -> bool:
        # ModelNet10 stores files as `<category>/<train|test>/<name>.off`.
        # Match on the parent directory name rather than a substring of the
        # full path. A substring check would also accept test files whenever
        # `root_dir` itself contains "train" (e.g. `/home/me/training/...`).
        return osp.basename(osp.dirname(path)) == 'train'

    class _InMemoryCache:
        # Replays items from memory once `source` has been fully consumed.
        # Items are yielded as-is (not copied).
        # NOTE(review): assumes the downstream PyG transforms do not modify
        # their input `Data` in place — confirm.
        def __init__(self, source):
            self.source = source
            self.items = None  # Set after the first complete pass.

        def __iter__(self):
            if self.items is not None:
                yield from self.items
                return
            items = []
            for item in self.source:
                items.append(item)
                yield item
            # Only a complete pass is cached. An abandoned pass (e.g. taking
            # one example via `next(iter(...))`) restarts from `source`.
            self.items = items

    datapipe = FileLister([root_dir], masks='*.off', recursive=True)
    datapipe = datapipe.filter(is_train)
    datapipe = datapipe.read_mesh()
    # Cache the parsed meshes. A plain iterator such as `itertools.tee(...)`
    # would be exhausted after the first epoch. `deepcopy=False` avoids
    # deep-copying the whole cache on every epoch.
    datapipe = IterableWrapper(_InMemoryCache(datapipe), deepcopy=False)
    datapipe = datapipe.sample_points(1024)  # Use PyG transforms from here.
    datapipe = datapipe.knn_graph(k=8)
    return datapipe
# Registry of runnable examples: maps each `--task` choice to the factory
# that builds its DataPipe.
DATAPIPES = dict(
    molecule=molecule_datapipe,
    mesh=mesh_datapipe,
)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', default='molecule', choices=DATAPIPES.keys())
    task = parser.parse_args().task

    datapipe = DATAPIPES[task]()
    print('Example output:')
    print(next(iter(datapipe)))

    # Shuffling + Batching support:
    datapipe = datapipe.shuffle().batch_graphs(batch_size=32)

    def run_epoch(message: str) -> None:
        # Drain `datapipe` once and report how long the pass took.
        print(message)
        start = time.perf_counter()
        for _ in datapipe:
            pass
        print(f'Done! [{time.perf_counter() - start:.2f}s]')

    # The first epoch will take longer than the remaining ones...
    run_epoch('Iterating over all data...')
    run_epoch('Iterating over all data a second time...')