-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainer.py
More file actions
106 lines (74 loc) · 2.86 KB
/
trainer.py
File metadata and controls
106 lines (74 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import logging
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
logging.basicConfig(level = logging.INFO, format = "%(asctime)s - %(message)s")
logger = logging.getLogger(__name__)
def train(model, optimizer, train_dl, val_dl, device=None, epochs=10):
    """Train `model` on `train_dl`, validating on `val_dl` after each epoch.

    Args:
        model: the torch.nn.Module to optimize (moved to `device` in place).
        optimizer: a torch optimizer already bound to `model`'s parameters.
        train_dl: DataLoader yielding (inputs, labels) batches for training.
        val_dl: DataLoader used by `evaluate` at the end of each epoch.
        device: target device for model and batches; None leaves them as-is.
        epochs: number of full passes over `train_dl`.

    Returns:
        dict with per-epoch lists: 'acc', 'loss', 'val_acc', 'val_loss'.
    """
    # Bug fix: original called model.cuda(device), which ignores CPU devices
    # and crashes on machines without CUDA; .to(device) honors the argument
    # (and is a no-op when device is None).
    model.to(device)
    history = {
        'acc': [], 'loss': [],
        'val_acc': [], 'val_loss': []
    }
    # Bug fix: int(len(dataset) / batch_size) undercounts when the final
    # batch is partial; len(dataloader) is the exact number of batches.
    batch_num = len(train_dl)
    for epoch in range(1, epochs + 1):
        model.train()
        steps = 0
        total_loss = 0.
        correct_num = 0
        for (x, y) in train_dl:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            scores = model(x)
            loss = F.cross_entropy(scores, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Class prediction = argmax over the class dimension.
            y_pred = torch.max(scores, 1)[1]
            correct_num += (y_pred == y).sum().item()
            steps += 1
            # Lightweight in-place progress indicator every 100 batches:
            # backspaces erase the previous message before rewriting it.
            if steps % 100 == 0:
                info = 'epoch {:<2}: {:.2%}'.format(epoch, steps / batch_num)
                sys.stdout.write('\b' * len(info))
                sys.stdout.write(info)
                sys.stdout.flush()
                sys.stdout.write('\b' * len(info))
                sys.stdout.flush()
        train_acc = correct_num / len(train_dl.dataset)
        # NOTE(review): this divides the sum of per-batch MEAN losses by the
        # dataset size, so it is not a true per-sample average; kept as-is
        # because `evaluate` uses the same convention and callers may rely
        # on comparable scales.
        train_loss = total_loss / len(train_dl.dataset)
        history['acc'].append(train_acc)
        history['loss'].append(train_loss)
        val_loss, val_acc = evaluate(model, val_dl, device=device)
        history['val_acc'].append(val_acc)
        history['val_loss'].append(val_loss)
        logger.info("epoch {} - loss: {:.2f} acc: {:.2f} - val_loss: {:.2f} val_acc: {:.2f}"\
            .format(epoch, train_loss, train_acc, val_loss, val_acc))
    return history
def predict(model, dl, device=None):
    """Run inference over `dl` and return predicted class indices.

    Args:
        model: trained torch.nn.Module producing per-class scores.
        dl: iterable of (inputs, labels) batches; labels are ignored.
        device: device to move each input batch to (None = leave as-is).

    Returns:
        1-D numpy array of predicted class indices, in dataset order.
    """
    model.eval()
    y_pred = []
    # Fix: wrap inference in no_grad() — the original tracked gradients for
    # every forward pass, wasting memory and time; outputs are unchanged.
    with torch.no_grad():
        for x, _ in dl:
            x = x.to(device)
            scores = model(x)
            # argmax over the class dimension for each sample in the batch
            y_pred_batch = torch.max(scores, 1)[1]
            y_pred.append(y_pred_batch)
    y_pred = torch.cat(y_pred, dim=0)
    return y_pred.cpu().numpy()
def evaluate(model, dl, device=None):
    """Compute loss and accuracy of `model` over the labeled batches in `dl`.

    Args:
        model: torch.nn.Module producing per-class scores.
        dl: DataLoader yielding (inputs, labels); must expose `.dataset`.
        device: device to move batches to (None = leave as-is).

    Returns:
        (avg_loss, avg_acc) tuple. Note: avg_loss is the sum of per-batch
        MEAN losses divided by dataset size — kept for consistency with the
        convention used by `train`.
    """
    model.eval()
    total_loss = 0.0
    correct_num = 0
    # Fix: evaluation must not build autograd graphs; the original tracked
    # gradients on every forward pass. Values returned are unchanged.
    with torch.no_grad():
        for x, y in dl:
            x = x.to(device)
            y = y.to(device)
            scores = model(x)
            loss = F.cross_entropy(scores, y)
            total_loss += loss.item()
            # predicted class = argmax over the class dimension
            y_pred = torch.max(scores, 1)[1]
            correct_num += (y_pred == y).sum().item()
    avg_loss = total_loss / len(dl.dataset)
    avg_acc = correct_num / len(dl.dataset)
    return avg_loss, avg_acc