train.py
"""
Copyright 2020 Twitter, Inc.
SPDX-License-Identifier: Apache-2.0
Modified by Daeho Um (daehoum1@snu.ac.kr)
"""
import logging

logger = logging.getLogger(__name__)


def train(model, x, data, optimizer, criterion, train_loader=None, device="cuda"):
    """Run one training epoch: mini-batched if a loader is given, else full-batch."""
    model.train()
    return (
        train_sampled(model, train_loader, x, data, optimizer, criterion, device)
        if train_loader
        else train_full_batch(model, x, data, optimizer, criterion)
    )


def train_full_batch(model, x, data, optimizer, criterion):
    """Perform a single optimization step over the entire graph."""
    model.train()
    optimizer.zero_grad()
    # Forward pass on the full graph; compute the loss on training nodes only.
    y_pred = model(x, data.edge_index)[data.train_mask]
    y_true = data.y[data.train_mask].squeeze()
    loss = criterion(y_pred, y_true)
    loss.backward()
    optimizer.step()
    return loss


def train_sampled(model, train_loader, x, data, optimizer, criterion, device):
    """Run one epoch of neighborhood-sampled mini-batch training."""
    model.train()
    total_loss = 0
    for batch_size, n_id, adjs in train_loader:
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples, one per layer.
        adjs = [adj.to(device) for adj in adjs]
        # `n_id` lists every node in this batch's computation graph; the first
        # `batch_size` entries are the target (seed) nodes being classified.
        x_batch = x[n_id]
        optimizer.zero_grad()
        y_pred = model(x_batch, adjs=adjs, full_batch=False)
        y_true = data.y[n_id[:batch_size]].squeeze()
        loss = criterion(y_pred, y_true)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        logger.debug(f"Batch loss: {loss.item():.2f}")
    return total_loss / len(train_loader)
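

# ---------------------------------------------------------------------------
# Minimal full-batch usage sketch (not part of the original file). The `GCN`
# model, the Cora `Planetoid` dataset, and the hyperparameters below are
# illustrative assumptions; the repository's actual model and data may differ.
# The sampled path instead expects a loader yielding `(batch_size, n_id, adjs)`
# tuples (e.g. PyTorch Geometric's NeighborSampler) and a model whose forward
# accepts the `adjs` and `full_batch` keyword arguments.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    import torch.nn.functional as F
    from torch_geometric.datasets import Planetoid
    from torch_geometric.nn import GCNConv

    class GCN(torch.nn.Module):
        """A two-layer GCN, assumed here purely for illustration."""

        def __init__(self, in_dim, hidden_dim, out_dim):
            super().__init__()
            self.conv1 = GCNConv(in_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, out_dim)

        def forward(self, x, edge_index):
            x = F.relu(self.conv1(x, edge_index))
            return self.conv2(x, edge_index)

    dataset = Planetoid(root="data/Cora", name="Cora")
    data = dataset[0]
    model = GCN(dataset.num_features, 64, dataset.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, 201):
        # No loader is passed, so `train` dispatches to `train_full_batch`.
        loss = train(model, data.x, data, optimizer, criterion)
        if epoch % 50 == 0:
            print(f"Epoch {epoch:03d}, loss {loss.item():.4f}")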