forked from mfederici/dl-kit
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwandb.py
More file actions
32 lines (28 loc) · 1.38 KB
/
wandb.py
File metadata and controls
32 lines (28 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import pytorch_lightning.loggers as loggers
import matplotlib.pyplot as plt
import wandb
from code.loggers.log_entry import LogEntry, IMAGE_ENTRY, SCALARS_ENTRY, SCALAR_ENTRY, PLOT_ENTRY
class WandbLogger(loggers.WandbLogger):
    """PyTorch Lightning ``WandbLogger`` extended with typed ``LogEntry`` logging.

    Adds a :meth:`log` method that dispatches on ``LogEntry.data_type`` and
    forwards the payload to the underlying wandb run exposed as
    ``self.experiment``.
    """

    def log(self, name: str, log_entry: LogEntry, global_step: int = None, counters: dict = None) -> None:
        """Stage a single ``LogEntry`` on the wandb run without committing it.

        Every call uses ``commit=False``, so the data is staged and is written
        out by a later commit (e.g. the next committing ``log`` call on the run).

        Args:
            name: Key under which the value is logged. For ``SCALARS_ENTRY`` it
                is used as a prefix: each sub-value goes to ``"<name>/<sub_name>"``.
            log_entry: Entry to log; its ``data_type`` selects how ``value`` is
                handled (raw scalar, dict of scalars, image, or plot).
            global_step: Trainer step. When given, it is recorded under
                ``'trainer/global_step'`` and, for image/plot entries, also
                passed as wandb's ``step``.
            counters: Optional extra key/value pairs logged alongside the entry.
                The mapping is copied, never mutated.

        Raises:
            ValueError: If ``log_entry.data_type`` is not a supported type.
        """
        # Copy so the caller's counters dict is not modified by the keys added below.
        entry = {} if counters is None else dict(counters)
        # Only record the step when one was provided, instead of logging a null step value.
        if global_step is not None:
            entry['trainer/global_step'] = global_step

        data_type = log_entry.data_type
        if data_type == SCALAR_ENTRY:
            entry[name] = log_entry.value
            # NOTE(review): the scalar branches do not pass step=..., unlike the
            # image/plot branches below; confirm this asymmetry is intended.
            self.experiment.log(entry, commit=False)
        elif data_type == SCALARS_ENTRY:
            # Flatten {sub_name: v} into "name/sub_name" keys.
            for sub_name, v in log_entry.value.items():
                entry['%s/%s' % (name, sub_name)] = v
            self.experiment.log(entry, commit=False)
        elif data_type == IMAGE_ENTRY:
            entry[name] = wandb.Image(log_entry.value)
            self.experiment.log(data=entry, step=global_step, commit=False)
            self._close_figure(log_entry.value)
        elif data_type == PLOT_ENTRY:
            # The plot object is handed to wandb unchanged.
            entry[name] = log_entry.value
            self.experiment.log(data=entry, step=global_step, commit=False)
            self._close_figure(log_entry.value)
        else:
            raise ValueError('Data type %s is not recognized by WandbLogger' % data_type)

    @staticmethod
    def _close_figure(value) -> None:
        """Close ``value`` if it is a matplotlib Figure, to free its memory.

        Other payloads (e.g. arrays or images given to ``wandb.Image``) are left
        alone: ``plt.close`` raises ``TypeError`` for such objects.
        """
        if isinstance(value, plt.Figure):
            plt.close(value)