From 6559078dcc1aa32ef2ceca607cd4eb6b1973e4ae Mon Sep 17 00:00:00 2001 From: yukang Date: Thu, 16 Feb 2023 11:16:51 +0800 Subject: [PATCH 1/2] Resume training from a checkpoint For example, torchrun --nnodes=1 --nproc_per_node=8 train.py --model DiT-XL/2 --data-path /path/to/imagenet/train --resume results/000/checkpoints/0100000.pt --- train.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index c69d08cf..612e8522 100644 --- a/train.py +++ b/train.py @@ -193,14 +193,28 @@ def main(args): model.train() # important! This enables embedding dropout for classifier-free guidance ema.eval() # EMA model should always be in eval mode - # Variables for monitoring/logging purposes: + epochs_start = 0 train_steps = 0 + + if args.resume: + if rank == 0: + print("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume) + model.load_state_dict(_ddp_dict(checkpoint['model']), strict=True) + ema.load_state_dict(checkpoint['ema'], strict=True) + opt.load_state_dict(checkpoint['opt']) + del checkpoint + train_steps = int(args.resume.split('/')[-1].split('.')[0]) + epochs_start = int(train_steps / (len(dataset) / args.global_batch_size)) + print("=> loaded checkpoint '{}' (epochs {})".format(args.resume, epochs_start)) + + # Variables for monitoring/logging purposes: log_steps = 0 running_loss = 0 start_time = time() logger.info(f"Training for {args.epochs} epochs...") - for epoch in range(args.epochs): + for epoch in range(epochs_start, args.epochs): sampler.set_epoch(epoch) logger.info(f"Beginning epoch {epoch}...") for x, y in loader: @@ -273,5 +287,6 @@ def main(args): parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--log-every", type=int, default=100) parser.add_argument("--ckpt-every", type=int, default=50_000) + parser.add_argument("--resume", type=str, default='') args = parser.parse_args() main(args) From 3f14f1dc8804eed5b75e5496010686855066cf56 Mon Sep 17 00:00:00 2001 From: yukang Date: Tue, 21 Feb 2023 22:26:10 +0800 Subject: [PATCH 2/2] Update train.py --- train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/train.py b/train.py index 612e8522..0ead71d9 100644 --- a/train.py +++ b/train.py @@ -112,6 +112,12 @@ def center_crop_arr(pil_image, image_size): crop_x = (arr.shape[1] - image_size) // 2 return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) +def _ddp_dict(_dict): + new_dict = {} + for k in _dict: + new_dict['module.' + k] = _dict[k] + return new_dict + ################################################################################# # Training Loop #