diff --git a/train.py b/train.py index c69d08cf..0ead71d9 100644 --- a/train.py +++ b/train.py @@ -112,6 +112,12 @@ def center_crop_arr(pil_image, image_size): crop_x = (arr.shape[1] - image_size) // 2 return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) +def _ddp_dict(_dict): + new_dict = {} + for k in _dict: + new_dict['module.' + k] = _dict[k] + return new_dict + ################################################################################# # Training Loop # @@ -193,14 +199,28 @@ def main(args): model.train() # important! This enables embedding dropout for classifier-free guidance ema.eval() # EMA model should always be in eval mode - # Variables for monitoring/logging purposes: + epochs_start = 0 train_steps = 0 + + if args.resume: + if rank == 0: + print("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume) + model.load_state_dict(_ddp_dict(checkpoint['model']), strict=True) + ema.load_state_dict(checkpoint['ema'], strict=True) + opt.load_state_dict(checkpoint['opt']) + del checkpoint + train_steps = int(args.resume.split('/')[-1].split('.')[0]) + epochs_start = int(train_steps / (len(dataset) / args.global_batch_size)) + print("=> loaded checkpoint '{}' (epochs {})".format(args.resume, epochs_start)) + + # Variables for monitoring/logging purposes: log_steps = 0 running_loss = 0 start_time = time() logger.info(f"Training for {args.epochs} epochs...") - for epoch in range(args.epochs): + for epoch in range(epochs_start, args.epochs): sampler.set_epoch(epoch) logger.info(f"Beginning epoch {epoch}...") for x, y in loader: @@ -273,5 +293,6 @@ def main(args): parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--log-every", type=int, default=100) parser.add_argument("--ckpt-every", type=int, default=50_000) + parser.add_argument("--resume", type=str, default='') args = parser.parse_args() main(args)