"""Command-line parameter definitions for PyTorch training.

Importable module exposing:
    parser     -- argparse.ArgumentParser carrying every training option
    model_name -- list of supported model identifiers (consumed by main.py
                  to build its name -> class dictionary, so keep the order)
"""
import argparse

# Supported model identifiers; order must match main.py's model_class list.
model_name = ['vgg', 'resnet', 'se_resnet']

parser = argparse.ArgumentParser(description='PyTorch training')

# One entry per option: (option strings, keyword arguments for add_argument).
# Kept as a table so every flag, default, metavar and help text is visible at
# a glance and registered by a single loop below.
_OPTIONS = [
    # model / dataset selection
    (('--model',), dict(type=str, default='vgg', choices=model_name,
                        help='model (default: vgg)')),
    (('--dataset',), dict(type=str, default='cifar10',
                          help='dataset (default: cifar10)')),
    # training-mode switches
    (('--sparsity-regularization', '-sr'),
     dict(dest='sr', action='store_true',
          help='train with channel sparsity regularization')),
    (('-se',), dict(dest='se', action='store_true',
                    help='train with SEBlock')),
    # sparsity penalty coefficient
    (('--p',), dict(type=float, default=0.0001,
                    help='penalty (default: 0.0001)')),
    # batch sizes
    (('--batch-size',), dict(type=int, default=100, metavar='N',
                             help='input batch size for training (default: 100)')),
    (('--validate-batch-size',), dict(type=int, default=1000, metavar='N',
                                      help='input batch size for validation (default: 1000)')),
    # checkpoints
    (('--fine-tune',), dict(default='', type=str, metavar='PATH',
                            help='fine-tune from pruned model')),
    (('--resume',), dict(default='', type=str, metavar='PATH',
                         help='path to latest checkpoint (default: none)')),
    # schedule
    (('--epochs',), dict(type=int, default=160, metavar='N',
                         help='number of epochs to train (default: 160)')),
    (('--start-epoch',), dict(default=0, type=int, metavar='N',
                              help='manual epoch number (useful on restarts)')),
    # optimizer hyper-parameters
    (('--lr',), dict(type=float, default=0.1, metavar='LR',
                     help='learning rate (default: 0.1)')),
    (('--momentum',), dict(type=float, default=0.9, metavar='M',
                           help='SGD momentum (default: 0.9)')),
    (('--weight-decay', '--wd'), dict(default=1e-4, type=float, metavar='W',
                                      help='weight decay (default: 1e-4)')),
    # environment
    (('--no-cuda',), dict(action='store_true', default=False,
                          help='disables CUDA training')),
    (('--seed',), dict(type=int, default=1, metavar='S',
                       help='random seed (default: 1)')),
    (('--log-interval',), dict(type=int, default=100, metavar='N',
                               help='how many batches to wait before logging training status')),
    (('--save-path',), dict(type=str, default='./save/', metavar='PATH',
                            help='path to save checkpoint')),
    (('--num-workers',), dict(type=int, default=1,
                              help='how many thread to load data(default: 1)')),
    (('--num_classes',), dict(type=int, default=10)),
    # image-list dataset options
    (('--image-root-path',), dict(default='', type=str, metavar='PATH',
                                  help='path to root path of images (default: none)')),
    (('--image-train-list',), dict(default='', type=str, metavar='PATH',
                                   help='path to training list (default: none)')),
    (('--image-validate-list',), dict(default='', type=str, metavar='PATH',
                                      help='path to validation list (default: none)')),
    (('--img-size', '--img_size'), dict(default=144, type=int)),
    # NOTE(review): crop size is slated to move into transfrom_config.xml
    # per the original author's comment -- confirm before removing the flag.
    (('--crop-size', '--crop_size'), dict(default=128, type=int)),
    # knowledge distillation
    (('--teacher_model',), dict(default=None, type=str, metavar='PATH',
                                help='teacher model for knowledge distillation')),
    (('--loss_ratio',), dict(default=0.2, type=float,
                             help='ratio to control knowledge distillation\'s loss')),
]

for _flags, _kwargs in _OPTIONS:
    parser.add_argument(*_flags, **_kwargs)
class dataset_factory(object):
    """Factory for matched (train, validation) DataLoader pairs.

    If ``dataset_name`` names a torchvision dataset (matched
    case-insensitively against ``datasets.__all__``), both loaders wrap that
    dataset rooted at ``root_path``; otherwise they wrap the project's
    ``ImageList`` driven by file lists.
    """

    @staticmethod
    def get_train_loader_and_validate_loader(dataset_name, c_transform, arg_batch_size,
                                             arg_validate_batch_size, kwargs,
                                             dataset_config=None, root_path='./data',
                                             fileList_path='./data',
                                             validate_fileList_path=None):
        """Return ``(train_loader, validate_loader)``.

        Parameters:
            dataset_name: torchvision dataset name (any case) or any other
                string to select the ImageList branch.
            c_transform: transform applied to BOTH loaders. NOTE(review):
                train and validation therefore share augmentation -- confirm
                that is intended before shipping.
            arg_batch_size / arg_validate_batch_size: per-loader batch sizes.
            kwargs: extra DataLoader keyword args (num_workers, pin_memory...).
            dataset_config: currently unused; kept for interface stability.
            root_path: dataset root directory (torchvision) or image root
                (ImageList).
            fileList_path: training file list for the ImageList branch.
            validate_fileList_path: optional separate validation file list;
                defaults to ``fileList_path`` (the original behaviour, which
                validated on the training list).
        """
        # Map lower-cased names to torchvision's canonical dataset names so
        # lookup is case-insensitive end to end.  (Previously the membership
        # test lowered the name but the dict lookup did not, so e.g.
        # 'CIFAR10' passed the test and then raised KeyError.)
        low_datasets = [s.lower() for s in datasets.__all__]
        dataset_dict = dict(zip(low_datasets, datasets.__all__))

        if dataset_name.lower() in low_datasets:
            dataset_cls = getattr(datasets, dataset_dict[dataset_name.lower()])
            train_loader = torch.utils.data.DataLoader(
                # train=True and download=False are the constructor defaults.
                dataset_cls(root_path, transform=c_transform),
                batch_size=arg_batch_size, shuffle=True, **kwargs
            )
            validate_loader = torch.utils.data.DataLoader(
                dataset_cls(root_path, train=False, transform=c_transform),
                batch_size=arg_validate_batch_size, shuffle=False, **kwargs
            )
        else:
            if validate_fileList_path is None:
                validate_fileList_path = fileList_path
            train_loader = torch.utils.data.DataLoader(
                ImageList(root=root_path, fileList=fileList_path, transform=c_transform),
                batch_size=arg_batch_size, shuffle=True, **kwargs
            )
            validate_loader = torch.utils.data.DataLoader(
                ImageList(root=root_path, fileList=validate_fileList_path,
                          transform=c_transform),
                batch_size=arg_validate_batch_size, shuffle=False, **kwargs
            )
        return train_loader, validate_loader
nets.resnet_pre_activation import * from nets.se_resnet import * -from utils.load_imglist import ImageList from utils.convert_DataParallel_Model import convert_DataParallel_Model_to_Common_Model +from command_parameter import * +from dataset_factory import dataset_factory -# Training settings -parser = argparse.ArgumentParser(description='PyTorch training') - -parser.add_argument('--model', type=str, default='vgg', - help='model (default: vgg)') -parser.add_argument('--dataset', type=str, default='cifar10', - help='dataset (default: cifar10)') -parser.add_argument( - '--sparsity-regularization', - '-sr', - dest='sr', - action='store_true', - help='train with channel sparsity regularization') -parser.add_argument('-se', dest='se', action='store_true', - help='train with SEBlock') -parser.add_argument('--p', type=float, default=0.0001, - help='penalty (default: 0.0001)') -parser.add_argument('--batch-size', type=int, default=100, metavar='N', - help='input batch size for training (default: 100)') -parser.add_argument('--fine-tune', default='', type=str, metavar='PATH', - help='fine-tune from pruned model') -parser.add_argument( - '--validate-batch-size', - type=int, - default=1000, - metavar='N', - help='input batch size for validation (default: 1000)') -parser.add_argument('--epochs', type=int, default=160, metavar='N', - help='number of epochs to train (default: 160)') -parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='manual epoch number (useful on restarts)') -parser.add_argument('--lr', type=float, default=0.1, metavar='LR', - help='learning rate (default: 0.1)') -parser.add_argument('--momentum', type=float, default=0.9, metavar='M', - help='SGD momentum (default: 0.9)') -parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)') -parser.add_argument('--resume', default='', type=str, metavar='PATH', - help='path to latest checkpoint (default: none)') 
-parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') -parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') -parser.add_argument( - '--log-interval', type=int, default=100, metavar='N', - help='how many batches to wait before logging training status') -parser.add_argument( - '--save-path', - type=str, - default='./save/', - metavar='PATH', - help='path to save checkpoint') -parser.add_argument('--num-workers', type=int, default=1, - help='how many thread to load data(default: 1)') -parser.add_argument('--num_classes', type=int, default=10) -parser.add_argument('--image-root-path', default='', type=str, metavar='PATH', - help='path to root path of images (default: none)') -parser.add_argument('--image-train-list', default='', type=str, metavar='PATH', - help='path to training list (default: none)') -parser.add_argument( - '--image-validate-list', - default='', - type=str, - metavar='PATH', - help='path to validation list (default: none)') -parser.add_argument('--img-size', '--img_size', default=144, type=int) -parser.add_argument('--crop-size', '--crop_size', default=128, type=int) -parser.add_argument( - '--teacher_model', default=None, type=str, metavar='PATH', - help='teacher model for knowledge distillation') -parser.add_argument( - '--loss_ratio', default=0.2, type=float, - help='ratio to control knowledge distillation\'s loss') +# parser from command_parameter args = parser.parse_args() -args.cuda = not args.no_cuda and torch.cuda.is_available() - +# setting command parameters +args.cuda = (not args.no_cuda) and torch.cuda.is_available() torch.manual_seed(args.seed) if args.cuda: torch.cuda.manual_seed(args.seed) cudnn.benchmark = True -model_dict = { - 'vgg': vgg_diy, - 'resnet': preactivation_resnet164, - 'se_resnet': se_resnet_34 -} - - -if args.model not in model_dict: - raise ValueError('Name of network unknown %s' % args.model) -else: - if args.fine_tune: - 
load_pkl = torch.load(args.fine_tune) - model = model_dict[args.model]( - num_classes=args.num_classes, cfg=load_pkl['cfg']) - model.load_state_dict(load_pkl['model_state_dict']) - if args.teacher_model is not None: - teacher_model = model_dict[args.model]( - num_classes=args.num_classes) - teacher_model.load_state_dict(torch.load(args.teacher_model)) - else: - pass - #model = model_dict[args.model](num_classes=args.num_classes) - # model.load_state_dict(load_pkl) - args.save_path = os.path.join( - args.save_path, - 'fine_tune/' + args.model, - args.dataset) +# loading model and choose model +model_class = [vgg_diy,preactivation_resnet164,se_resnet_34] +model_dict = dict(list(zip(model_name,model_class))) +if args.fine_tune: + load_pkl = torch.load(args.fine_tune) + model = model_dict[args.model]( + num_classes=args.num_classes, cfg=load_pkl['cfg']) + model.load_state_dict(load_pkl['model_state_dict']) + if args.teacher_model is not None: + teacher_model = model_dict[args.model]( + num_classes=args.num_classes) + teacher_model.load_state_dict(torch.load(args.teacher_model)) else: - model = model_dict[args.model](num_classes=args.num_classes) - args.save_path = os.path.join(args.save_path, args.model, args.dataset) + pass + #model = model_dict[args.model](num_classes=args.num_classes) + # model.load_state_dict(load_pkl) + args.save_path = os.path.join( + args.save_path, + 'fine_tune/' + args.model, + args.dataset) +else: + model = model_dict[args.model](num_classes=args.num_classes) + args.save_path = os.path.join(args.save_path, args.model, args.dataset) +# dataset choice kwargs = {'num_workers': args.num_workers, 'pin_memory': True} if args.cuda else {} -if args.dataset == 'cifar10': + +# normalize config +''' + cifar10 normalize normalize = transforms.Normalize( - mean=[0.491, 0.482, 0.447], - std=[0.247, 0.243, 0.262]) - train_loader = torch.utils.data.DataLoader( - datasets.CIFAR10('./data', train=True, download=False, - transform=transforms.Compose([ - 
transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - #transforms.ColorJitter(brightness=1), - transforms.ToTensor(), - normalize - ]) - ), batch_size=args.batch_size, shuffle=True, **kwargs - ) - validate_loader = torch.utils.data.DataLoader( - datasets.CIFAR10('./data', train=False, - transform=transforms.Compose([ - transforms.ToTensor(), - normalize - ]) - ), - batch_size=args.validate_batch_size, shuffle=False, **kwargs - ) -elif args.dataset == 'cifar100': + mean=[0.491, 0.482, 0.447], + std=[0.247, 0.243, 0.262]) + + cifar100 normalize normalize = transforms.Normalize( mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]) - # normalize = transforms.Normalize(mean=[.5,.5,.5], std=[.5,.5,.5]) - # normalize = transforms.Normalize((.5,.5,.5),(.5,.5,.5)) - train_loader = torch.utils.data.DataLoader( - datasets.CIFAR100('./data', train=True, download=True, - transform=transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize - ]) - ), batch_size=args.batch_size, shuffle=True, **kwargs - ) - validate_loader = torch.utils.data.DataLoader( - datasets.CIFAR100('./data', train=False, - transform=transforms.Compose([ - transforms.ToTensor(), - normalize - ]) - ), - batch_size=args.validate_batch_size, shuffle=False, **kwargs - ) -else: - train_loader = torch.utils.data.DataLoader( - ImageList(root=args.image_root_path, fileList=args.image_train_list, - transform=transforms.Compose([ - transforms.Resize(size=(args.img_size, args.img_size)), - transforms.RandomCrop(args.crop_size), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - ]) - ), - batch_size=args.batch_size, shuffle=True, **kwargs - ) - validate_loader = torch.utils.data.DataLoader( - ImageList( - root=args.image_root_path, fileList=args.image_validate_list, - transform=transforms.Compose( - [transforms.Resize( - size=(args.crop_size, args.crop_size)), - transforms.ToTensor(), ])), - 
"""Training entry point (backup copy of main.py).

Parses command-line options, builds the selected network (optionally from a
pruned fine-tune checkpoint and with an optional distillation teacher),
constructs the data loaders via dataset_factory, and dispatches to the
appropriate Trainer variant.
"""
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torchvision import transforms

from utils.trainer import *
from nets.my_vgg import vgg_diy
from nets.resnet_pre_activation import *
from nets.se_resnet import *
from utils.convert_DataParallel_Model import convert_DataParallel_Model_to_Common_Model
from command_parameter import *
from dataset_factory import dataset_factory

# ---- command-line arguments -------------------------------------------------
# `parser` and `model_name` come from command_parameter.
args = parser.parse_args()

args.cuda = (not args.no_cuda) and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
cudnn.benchmark = True

# ---- model selection --------------------------------------------------------
# model_class order must match model_name's order in command_parameter.
model_class = [vgg_diy, preactivation_resnet164, se_resnet_34]
model_dict = dict(zip(model_name, model_class))

if args.fine_tune:
    # Resume from a pruned checkpoint: rebuild the net with the pruned cfg,
    # then load its weights.
    load_pkl = torch.load(args.fine_tune)
    model = model_dict[args.model](num_classes=args.num_classes, cfg=load_pkl['cfg'])
    model.load_state_dict(load_pkl['model_state_dict'])
    if args.teacher_model is not None:
        # Optional full-size teacher network for knowledge distillation.
        teacher_model = model_dict[args.model](num_classes=args.num_classes)
        teacher_model.load_state_dict(torch.load(args.teacher_model))
    args.save_path = os.path.join(args.save_path, 'fine_tune/' + args.model,
                                  args.dataset)
else:
    model = model_dict[args.model](num_classes=args.num_classes)
    args.save_path = os.path.join(args.save_path, args.model, args.dataset)

# ---- data loading -----------------------------------------------------------
kwargs = {'num_workers': args.num_workers,
          'pin_memory': True} if args.cuda else {}

# CIFAR-10 channel statistics; swap in the matching values for other datasets
# (e.g. CIFAR-100: mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]).
normalize = transforms.Normalize(
    mean=[0.491, 0.482, 0.447],
    std=[0.247, 0.243, 0.262])

c_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=1),
    transforms.ToTensor(),
    normalize,
])

test_path = '/home/leolau/pytorch/data'  # just for test; remove once the parameter is confirmed
# FIX: dataset_factory has no __init__ taking these arguments and no
# train_loader/validate_loader attributes -- the original
# `dataset_factory(args.dataset, ...)` call raised TypeError at runtime.
# Use the static factory method (as the updated main.py does).
train_loader, validate_loader = dataset_factory.get_train_loader_and_validate_loader(
    args.dataset, c_transform, args.batch_size, args.validate_batch_size,
    kwargs, root_path=test_path)

# ---- optimizer / loss functions ---------------------------------------------
optimizer = optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=args.lr,
    weight_decay=args.weight_decay,
    momentum=args.momentum,
    nesterov=True)
criterion = nn.CrossEntropyLoss()
transfer_criterion = nn.MSELoss()  # feature-matching loss for distillation

# ---- trainer selection ------------------------------------------------------
if args.sr:
    print('\nSparsity Training \n')
    trainer = Network_Slimming_Trainer(
        model=model,
        optimizer=optimizer,
        lr=args.lr,
        criterion=criterion,
        start_epoch=args.start_epoch,
        epochs=args.epochs,
        cuda=args.cuda,
        log_interval=args.log_interval,
        train_loader=train_loader,
        validate_loader=validate_loader,
        root=args.save_path,
        penalty=args.p,
    )
elif args.se:
    print('\nSE_ResNet Training \n')
    trainer = SE_Trainer(
        model=model,
        optimizer=optimizer,
        lr=args.lr,
        criterion=criterion,
        start_epoch=args.start_epoch,
        epochs=args.epochs,
        cuda=args.cuda,
        log_interval=args.log_interval,
        train_loader=train_loader,
        validate_loader=validate_loader,
        root=args.save_path,
        SEBlock=SEBlock,
    )
elif args.fine_tune and args.teacher_model:
    # FIX: the original tested `args.fine_tune is not None`, which is always
    # True because --fine-tune defaults to '' (empty string, never None);
    # truthiness is the intended check.  This also guarantees teacher_model
    # was constructed above before it is used here.
    print('\nTraining with Knowledge Distillation \n')
    trainer = Trainer(
        model=model,
        teacher_model=teacher_model,
        optimizer=optimizer,
        lr=args.lr,
        criterion=criterion,
        start_epoch=args.start_epoch,
        epochs=args.epochs,
        cuda=args.cuda,
        log_interval=args.log_interval,
        train_loader=train_loader,
        validate_loader=validate_loader,
        root=args.save_path,
        loss_ratio=args.loss_ratio,
        transfer_criterion=transfer_criterion,
    )
else:
    print('\nNormal Training \n')
    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        lr=args.lr,
        criterion=criterion,
        start_epoch=args.start_epoch,
        epochs=args.epochs,
        cuda=args.cuda,
        log_interval=args.log_interval,
        train_loader=train_loader,
        validate_loader=validate_loader,
        root=args.save_path,
    )
trainer.start()