-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
76 lines (73 loc) · 3.83 KB
/
train.py
File metadata and controls
76 lines (73 loc) · 3.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Default training hyperparameters. Each one seeds the default of the
# command-line flag with the same lower-case name (e.g. --epochs, --lr0).

# Schedule / optimisation
EPOCHS = 10
OPTIMIZER = 'AdamW'
MOMENTUM = 0.9
LR0 = 1e-4   # initial learning rate
LRF = 1e-4   # Ultralytics interprets this as a fraction of LR0 (final lr = LR0 * LRF)

# Data / augmentation
MOSAIC = 0.4         # mosaic augmentation probability
SINGLE_CLS = False   # whether to train treating all classes as one
import argparse
from ultralytics import YOLO
import os
import sys
def str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is ``True`` for
    every non-empty string (including ``"False"``), so it cannot switch a flag
    off. This converter accepts the usual spellings instead.

    Args:
        value: The raw argument text, or an already-boolean value (argparse
            does not run ``type`` on non-string defaults, but callers might).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling,
            so argparse reports a clean usage error instead of a traceback.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    # Every flag defaults to the module-level constant of the same
    # (upper-case) name, so defaults can be edited at the top of the file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=EPOCHS, help='Number of epochs')
    parser.add_argument('--mosaic', type=float, default=MOSAIC,
                        help='Mosaic augmentation probability (0.0-1.0)')
    parser.add_argument('--optimizer', type=str, default=OPTIMIZER, help='Optimizer')
    parser.add_argument('--momentum', type=float, default=MOMENTUM, help='Momentum')
    parser.add_argument('--lr0', type=float, default=LR0, help='Initial learning rate')
    # Ultralytics treats lrf as a fraction of lr0, not an absolute rate.
    parser.add_argument('--lrf', type=float, default=LRF,
                        help='Final learning rate as a fraction of lr0 (final lr = lr0 * lrf)')
    # type=bool would turn "--single_cls False" into True, so parse explicitly.
    # A bare "--single_cls" with no value enables it.
    parser.add_argument('--single_cls', type=str2bool, nargs='?', const=True,
                        default=SINGLE_CLS, help='Single class training (true/false)')
    # Defaults to GPU 0 as before; pass e.g. "cpu" or "0,1" to override.
    parser.add_argument('--device', default=0,
                        help="Training device, e.g. 0, '0,1' or 'cpu'")
    args = parser.parse_args()

    # abspath first: if __file__ is a bare relative name (e.g. "train.py" on
    # Python < 3.9), dirname() returns '' and os.chdir('') would fail.
    this_dir = os.path.dirname(os.path.abspath(__file__))
    # NOTE(review): presumably the chdir makes relative paths (e.g. inside
    # yolo_params.yaml and the runs/ output directory) resolve next to this
    # script -- confirm.
    os.chdir(this_dir)

    model = YOLO(os.path.join(this_dir, "yolov8s.pt"))
    results = model.train(
        data=os.path.join(this_dir, "yolo_params.yaml"),
        epochs=args.epochs,
        device=args.device,
        single_cls=args.single_cls,
        mosaic=args.mosaic,
        optimizer=args.optimizer,
        lr0=args.lr0,
        lrf=args.lrf,
        momentum=args.momentum,
    )
'''
Mixup boosts validation predictions but hurts test predictions.
Mosaic should not be set to 1.0.
'''
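'''
Model summary captured from a training run (Detect head with nc=1). Note: the
layer widths (16-256 channels) and ~3.0M parameters match YOLOv8n, not the
yolov8s.pt loaded above, so this log presumably comes from an earlier run.
from n params module arguments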
0 -1 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2]
1 -1 1 4672 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2]
2 -1 1 7360 ultralytics.nn.modules.block.C2f [32, 32, 1, True]
3 -1 1 18560 ultralytics.nn.modules.conv.Conv [32, 64, 3, 2]
4 -1 2 49664 ultralytics.nn.modules.block.C2f [64, 64, 2, True]
5 -1 1 73984 ultralytics.nn.modules.conv.Conv [64, 128, 3, 2]
6 -1 2 197632 ultralytics.nn.modules.block.C2f [128, 128, 2, True]
7 -1 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2]
8 -1 1 460288 ultralytics.nn.modules.block.C2f [256, 256, 1, True]
9 -1 1 164608 ultralytics.nn.modules.block.SPPF [256, 256, 5]
10 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
11 [-1, 6] 1 0 ultralytics.nn.modules.conv.Concat [1]
12 -1 1 148224 ultralytics.nn.modules.block.C2f [384, 128, 1]
13 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
14 [-1, 4] 1 0 ultralytics.nn.modules.conv.Concat [1]
15 -1 1 37248 ultralytics.nn.modules.block.C2f [192, 64, 1]
16 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2]
17 [-1, 12] 1 0 ultralytics.nn.modules.conv.Concat [1]
18 -1 1 123648 ultralytics.nn.modules.block.C2f [192, 128, 1]
19 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2]
20 [-1, 9] 1 0 ultralytics.nn.modules.conv.Concat [1]
21 -1 1 493056 ultralytics.nn.modules.block.C2f [384, 256, 1]
22 [15, 18, 21] 1 751507 ultralytics.nn.modules.head.Detect [1, [64, 128, 256]]
Model summary: 225 layers, 3,011,043 parameters, 3,011,027 gradients, 8.2 GFLOPs
'''