-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
Copy pathsign.py
95 lines (71 loc) · 2.68 KB
/
sign.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os.path as osp
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch.utils.data import DataLoader
import torch_geometric.transforms as T
from torch_geometric.datasets import Flickr
# Number of propagation hops handed to the SIGN transform. The training and
# evaluation code below reads the resulting per-hop features from
# `data.x1` ... `data.x{K}` alongside the raw `data.x`.
K = 2

# The dataset is stored under ../data/Flickr relative to this script.
path = osp.join(
    osp.dirname(osp.realpath(__file__)),
    '..',
    'data',
    'Flickr',
)

# Features are normalized first, then the SIGN transform is applied.
transform = T.Compose([
    T.NormalizeFeatures(),
    T.SIGN(K),
])
dataset = Flickr(path, transform=transform)
data = dataset[0]  # only the first graph of the dataset is used

# Each split is represented as a flat tensor of node indices, and a plain
# DataLoader batches those indices (not the features themselves).
# Only the training loader shuffles.
train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
train_loader = DataLoader(train_idx, batch_size=16 * 1024, shuffle=True)

val_idx = data.val_mask.nonzero(as_tuple=False).view(-1)
val_loader = DataLoader(val_idx, batch_size=32 * 1024)

test_idx = data.test_mask.nonzero(as_tuple=False).view(-1)
test_loader = DataLoader(test_idx, batch_size=32 * 1024)
class Net(torch.nn.Module):
    """Classifier with one linear branch per feature matrix.

    There are ``K + 1`` input branches. Each maps
    ``dataset.num_node_features`` inputs to 1024 hidden units; the branch
    outputs are concatenated and projected to ``dataset.num_classes``.
    """

    def __init__(self):
        super().__init__()

        # One input projection per feature matrix (K + 1 in total).
        self.lins = torch.nn.ModuleList([
            Linear(dataset.num_node_features, 1024) for _ in range(K + 1)
        ])
        # Output head operating on the concatenated branch activations.
        self.lin = Linear((K + 1) * 1024, dataset.num_classes)

    def forward(self, xs):
        """Return per-class log-probabilities for a batch.

        Args:
            xs: Sequence of feature tensors, paired positionally with the
                branches in ``self.lins``.

        Returns:
            Tensor of log-softmax scores over the last dimension.
        """
        # Linear -> ReLU -> dropout for every (input, branch) pair, in order.
        branches = [
            F.dropout(lin(x).relu(), p=0.5, training=self.training)
            for x, lin in zip(xs, self.lins)
        ]
        merged = torch.cat(branches, dim=-1)
        logits = self.lin(merged)
        return logits.log_softmax(dim=-1)
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train():
    """Run one optimization pass over ``train_loader``.

    Returns:
        The average loss per training example, where each batch's loss is
        weighted by the number of indices in that batch.
    """
    model.train()

    loss_sum = 0
    num_seen = 0
    for idx in train_loader:
        # Gather the raw features plus the K extra feature matrices
        # (x1 ... xK) for this batch of node indices and move them over.
        inputs = [
            data.x[idx].to(device),
            *[data[f'x{i}'][idx].to(device) for i in range(1, K + 1)],
        ]
        target = data.y[idx].to(device)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), target)
        batch_loss.backward()
        optimizer.step()

        batch_size = idx.numel()
        loss_sum += float(batch_loss) * batch_size
        num_seen += batch_size

    return loss_sum / num_seen
@torch.no_grad()
def test(loader):
    """Compute classification accuracy over the indices yielded by ``loader``.

    Args:
        loader: Iterable yielding batches of node indices.

    Returns:
        Fraction of nodes whose arg-max prediction equals the label.
    """
    model.eval()

    num_correct = 0
    num_seen = 0
    for idx in loader:
        # Same input layout as in training: x followed by x1 ... xK.
        inputs = [
            data.x[idx].to(device),
            *[data[f'x{i}'][idx].to(device) for i in range(1, K + 1)],
        ]
        target = data.y[idx].to(device)

        prediction = model(inputs).argmax(dim=-1)
        num_correct += int(prediction.eq(target).sum())
        num_seen += idx.numel()

    return num_correct / num_seen
# Train for 200 epochs, reporting accuracy on all three splits after each.
for epoch in range(1, 201):
    loss = train()
    # Evaluated in train -> val -> test order.
    train_acc, val_acc, test_acc = (
        test(split_loader)
        for split_loader in (train_loader, val_loader, test_loader)
    )
    print(
        f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
        f'Val: {val_acc:.4f}, Test: {test_acc:.4f}'
    )