forked from gbdl/BBI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
115 lines (98 loc) · 3.13 KB
/
utils.py
File metadata and controls
115 lines (98 loc) · 3.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Some helper functions for PyTorch
"""
# Adapted from https://github.com/kuangliu/pytorch-cifar
import os
import sys
import time
import torch
class progressBar:
    """Single-line text progress bar written to ``sys.stdout``.

    Each call to :meth:`next` redraws the bar, followed by the time since the
    previous call ("Step"), the time since the bar was started ("Tot"), an
    optional message, and a ``current+1/total`` counter placed near the
    middle of the bar.
    """

    def __init__(self, bar_length=65.0, term_width=None):
        """Create a progress bar.

        Arguments:
            bar_length: Number of character cells between the brackets.
            term_width: Terminal width in columns. When falsy, the width is
                queried from the environment, falling back to 30 columns.
        """
        self.bar_length = bar_length
        self.last_time = time.time()
        self.begin_time = self.last_time
        if term_width:
            self.term_width = term_width
        else:
            # Imported here because shutil is not among the module imports.
            import shutil

            # Replaces spawning `stty size`: no subprocess, no unclosed pipe,
            # and no bare `except`; the 30-column fallback is preserved.
            self.term_width = shutil.get_terminal_size(fallback=(30, 24)).columns

    def next(self, current, total, msg=None):
        """Draw the bar for step ``current`` (0-based) out of ``total``.

        Arguments:
            current: Index of the step that just finished; ``0`` restarts
                the total-time clock.
            total: Total number of steps. A zero total draws a full bar
                instead of raising ``ZeroDivisionError``.
            msg: Optional text appended after the timing information.
        """
        if current == 0:
            self.begin_time = time.time()  # Reset for new bar.

        if total:
            cur_len = int(self.bar_length * current / total)
        else:
            # Nothing to iterate over: show the bar as complete.
            cur_len = int(self.bar_length)
        rest_len = int(self.bar_length - cur_len) - 1

        # Negative repeat counts yield "", matching an empty range().
        bar = " [" + "=" * cur_len + ">" + "." * rest_len + "]"

        cur_time = time.time()
        step_time = cur_time - self.last_time
        self.last_time = cur_time
        tot_time = cur_time - self.begin_time

        status = " Step: %s | Tot: %s" % (
            self.format_time(step_time),
            self.format_time(tot_time),
        )
        if msg:
            status += " | " + msg

        # Pad out to the terminal width, then back up to the center of the
        # bar so the counter is drawn over it.
        padding = " " * (self.term_width - int(self.bar_length) - len(status) - 3)
        rewind = "\b" * (self.term_width - int(self.bar_length / 2) + 2)
        counter = " %d/%d " % (current + 1, total)

        # Carriage return keeps redrawing the same line until the last step.
        line_end = "\r" if current < total - 1 else "\n"

        sys.stdout.write(bar + status + padding + rewind + counter + line_end)
        sys.stdout.flush()

    @staticmethod
    def format_time(seconds):
        """Format a duration in seconds using its two largest non-zero units.

        Units are days (``D``), hours (``h``), minutes (``m``), seconds
        (``s``) and milliseconds (``ms``); e.g. ``3661`` -> ``"1h1m"``.
        Returns ``"0ms"`` when every unit is zero.
        """
        days = int(seconds / 3600 / 24)
        seconds = seconds - days * 3600 * 24
        hours = int(seconds / 3600)
        seconds = seconds - hours * 3600
        minutes = int(seconds / 60)
        seconds = seconds - minutes * 60
        whole_seconds = int(seconds)
        millis = int((seconds - whole_seconds) * 1000)

        units = (
            (days, "D"),
            (hours, "h"),
            (minutes, "m"),
            (whole_seconds, "s"),
            (millis, "ms"),
        )
        parts = ["%d%s" % (value, suffix) for value, suffix in units if value > 0]
        # Only the two most significant non-zero units are shown.
        return "".join(parts[:2]) or "0ms"
def L2(params, l2, device):
    """Compute the L2-regularization penalty over ``params``.

    The result is ``0.5 * l2 * sum(p.norm(2) ** 2 for p in params)``.

    Arguments:
        params: Iterable of PyTorch tensors (e.g. network parameters).
        l2: L2 coefficient; no penalty is computed when it is not positive.
        device: Device on which the accumulator tensor is created.

    Returns:
        The float ``0.0`` when ``l2 <= 0.0``; otherwise a 0-dim tensor
        created on ``device`` holding the scaled penalty.
    """
    if l2 <= 0.0:
        return 0.0

    squared_norm_total = torch.tensor(0.0, device=device)
    for weights in params:
        squared_norm = weights.norm(2).pow(2)
        # In-place accumulation keeps the accumulator's device and dtype.
        squared_norm_total.add_(squared_norm)

    scaled = squared_norm_total * l2
    return scaled * 0.5