diff --git a/.gitignore b/.gitignore index b5be513..c97168e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__/ -NWPU-RESISC45/ +*/NWPU-RESISC45/ checkpoints_*/ +*/venv \ No newline at end of file diff --git a/classification/a.py b/classification/a.py new file mode 100644 index 0000000..c4b3a99 --- /dev/null +++ b/classification/a.py @@ -0,0 +1,67 @@ +def optimize(model, +criterion, +train_data, +labels, +sam=True, +mixup_criterion=None, +labels2=None, +mixup_lam=None, +distiller=None): + if mixup_criterion is not None: + assert len(labels2) == len(labels) + assert mixup_lam is not None + + outputs = model(train_data) + + if mixup_criterion: + train_loss = mixup_criterion(criterion, outputs, labels, labels2, mixup_lam) + elif distiller: + train_loss = distiller(train_data, labels) + else: + train_loss = criterion(outputs, labels) + + #Calculate batch accuracy and accumulate in epoch accuracy + epoch_loss += train_loss / len(train_loader) + output_labels = outputs.argmax(dim=1) + train_acc = (output_labels == train_labels).float().mean() + epoch_accuracy += train_acc / len(train_loader) + + if sam: + train_loss.backward() #Gradient of loss + optimizer.first_step(zero_grad=True) #Perturb weights + outputs = vit(train_data) #Outputs based on perturbed weights + if mixup_criterion: + perturbed_loss = mixup_criterion(criterion, outputs, labels, labels2, mixup_lam) + else: + perturbed_loss = criterion(outputs, train_labels) #Loss with perturbed weights + perturbed_loss.backward()#Gradient of perturbed loss + optimizer.second_step(zero_grad=True) #Unperturb weights and updated weights based on perturbed losses + optimizer.zero_grad() #Set gradients of optimized tensors to zero to prevent gradient accumulation + iteration += 1 + progress_bar.update(1) + + else: + # is_second_order attribute is added by timm on one optimizer + # (adahessian) + loss_scaler.scale(train_loss).backward( + create_graph=( + hasattr(optimizer, "is_second_order") + and optimizer.is_second_order + ) + ) + if optimizer_args.clip_grad is not None: + # unscale the gradients of optimizer's params in-place + loss_scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_( + vit.parameters(), optimizer_args.clip_grad + ) + + n_accum += 1 + + if n_accum == n_batch_accum: + n_accum = 0 + loss_scaler.step(optimizer) + loss_scaler.update() + + iteration += 1 + progress_bar.update(1) \ No newline at end of file diff --git a/classification/data_configs/data_config.json b/classification/data_configs/data_config.json index c486ace..84158d8 100644 --- a/classification/data_configs/data_config.json +++ b/classification/data_configs/data_config.json @@ -5,9 +5,13 @@ "test_files": "test_imagepaths.txt", "label_map": "label_map.json", "number_of_classes": 45, + "mixup": false, + "mixup_alpha": 0.7, "transform_ops_train": { "RandomResizedCrop": 224, "RandomHorizontalFlip": null, + "RandomVerticalFlip": null, + "RandAugment": null, "Normalize": { "Mean": [0.485, 0.456, 0.406], "Std": [0.229, 0.224, 0.225] diff --git a/classification/mixup.py b/classification/mixup.py new file mode 100644 index 0000000..fb7a23e --- /dev/null +++ b/classification/mixup.py @@ -0,0 +1,23 @@ +import numpy as np +import torch + +def mixup_data(X, y, alpha=1): + """Implement mixup. + + Returns mixed inputs, pairs of targets, and lambda, + which represents fraction of mixup. + + https://arxiv.org/abs/1710.09412 + """ + lam = np.random.beta(alpha, alpha) if alpha > 0 else 1 + + batch_size = X.size()[0] + index = torch.randperm(batch_size).cuda() if torch.cuda.is_available() else torch.randperm(batch_size) + + mixed_X = lam*X + (1-lam)*X[index, :] + y_1, y_2 = y, y[index] + + return mixed_X, y_1, y_2, lam + +def mixup_criterion(criterion, pred, y_1, y_2, lam): + return lam*criterion(pred, y_1) + (1-lam)*criterion(pred, y_2) \ No newline at end of file diff --git a/classification/requirements.txt b/classification/requirements.txt index 73fd313..e75b036 100644 --- a/classification/requirements.txt +++ b/classification/requirements.txt @@ -6,3 +6,7 @@ psutil numpy scikit-learn tqdm +wandb +pynvml +ninja +transformers \ No newline at end of file diff --git a/classification/sam.py b/classification/sam.py new file mode 100644 index 0000000..dfe1a3f --- /dev/null +++ b/classification/sam.py @@ -0,0 +1,116 @@ +import torch + +class SAM(torch.optim.Optimizer): + """Sharpness aware minimization. + + Parameters + ---------- + params: + model parameters + base_optimizer: torch.optim.Optimizer + e.g. ADAMW or SGD + rho: float + hyperparameter for SAM - perturbation strength (default 0.05) + gsam_alpha: float + hyperparameter for GSAM (default 0.05) + GSAM: bool + whether or not to use Surrogate Gap Sharpness Aware Minimization + adaptive: bool + whether or not to use Adaptive SAM + + Public Methods + -------------- + first_step(self, zero_grad): + Perturb weights. + unperturb(self, zero_grad): + Unperturb weights. + second_step(self, zero_grad): + Unperturb and update weights. + load_state_dict(self, state_dict): + Copies parameters and buffers from state_dict + into this module and its descendants. + + """ + def __init__(self, params, base_optimizer, rho=0.05, gsam_alpha=0.05, GSAM=False, adaptive=False, **kwargs): + assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" + + defaults = dict(adaptive=adaptive, **kwargs) + super(SAM, self).__init__(params, defaults) + + self.rho = rho + self.alpha = gsam_alpha + self.GSAM = GSAM + self.base_optimizer = base_optimizer(self.param_groups, **kwargs) + self.param_groups = self.base_optimizer.param_groups + self.defaults.update(self.base_optimizer.defaults) + + @torch.no_grad() + def first_step(self, zero_grad=False): + """Perturb weights.""" + grad_norm = self._grad_norm() + for group in self.param_groups: #Iterate over parameters/weights + scale = self.rho / (grad_norm + 1e-12) #Perturbation factor + + for p in group["params"]: + if p.grad is None: continue + self.state[p]["old_p"] = p.data.clone() #Save old parameter to unperturb later + if self.GSAM: self.state[p]["old_p_grad"] = p.grad.data.clone() #Save old gradient (if GSAM) + e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) #Calculate perturbation + p.add_(e_w) #Climb to the local maximum "w + e(w)" - perturb + if zero_grad: self.zero_grad() + + @torch.no_grad() + def unperturb(self, zero_grad=False): + """Return to old parameters - remove perturbation.""" + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: continue + p.data = self.state[p]["old_p"] #Get back to "w" from "w + e(w)" + if zero_grad: self.zero_grad() + + @torch.no_grad() + def second_step(self, zero_grad=False): + """Unperturb weights and then update. + + If GSAM, decompose gradients before unperturbing weights.""" + if self.GSAM: + self._decompose_grad() + self.unperturb() + + self.base_optimizer.step() # do the actual "sharpness-aware" update + + if zero_grad: self.zero_grad() + + def _grad_norm(self): + shared_device = self.param_groups[0]["params"][0].device #Put everything on the same device, in case of model parallelism + norm = torch.norm( + torch.stack([ + ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) + for group in self.param_groups for p in group["params"] + if p.grad is not None + ]), + p=2 + ) + return norm + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self.base_optimizer.param_groups = self.param_groups + + def _decompose_grad(self): + """Decompose gradient of unperturbed loss into directions parallel and + perpendicular to the gradients of the perturbed losses, for GSAM. + Subtract perpendicular component from perturbed loss gradients. + """ + for group in self.param_groups: + for p in group['params']: + if p.grad is None: continue + old_grad = self.state[p]["old_p_grad"] + if old_grad is None: continue + #Find factor of component parallel to perturbed loss + #Take dot product between two vectors. + a = torch.dot(p.grad.data.view(-1), old_grad.view(-1))/torch.norm(p.grad.data)**2 + perp = old_grad - a*p.grad.data #Component perpendicular to perturbed loss = vector - parallel component + norm_perp = perp / torch.norm(perp) #Normalise perpendicular component + #Subtract perp component from perturbed loss gradients, with factor alpha. + p.grad.data.sub_(self.alpha * norm_perp) \ No newline at end of file diff --git a/classification/train.py b/classification/train.py index 53dad49..547e2ef 100644 --- a/classification/train.py +++ b/classification/train.py @@ -14,6 +14,8 @@ from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from tqdm.auto import tqdm +from transformers import AdamW from utils.data_loader import Resisc45Loader from utils.models import get_models from utils.models import get_optimizer_args @@ -25,7 +27,8 @@ from utils.utils import init_distributed from utils.utils import parse_config from utils.utils import seed_everything - +from sam import SAM +from mixup import mixup_data, mixup_criterion def validation(val_loader, device, criterion, iteration, vit, distiller=None): total_val_loss = 0 @@ -54,6 +57,85 @@ def validation(val_loader, device, criterion, iteration, vit, distiller=None): def train_deit(rank, num_gpus, config): + + def optimize(model, + criterion, + train_data, + labels, + sam=True, + mixup_criterion=None, + labels2=None, + mixup_lam=None, + distiller=None): + #If using mixup, we require a second set of labels, + #and a value for lambda (mixup variable) + if mixup_criterion is not None: + assert len(labels2) == len(labels) + assert mixup_lam is not None + + nonlocal optimizer + nonlocal epoch_loss + nonlocal epoch_accuracy + nonlocal iteration + nonlocal progress_bar + nonlocal train_loader_len + nonlocal n_accum + + outputs = model(train_data) + + if mixup_criterion: + train_loss = mixup_criterion(criterion, outputs, labels, labels2, mixup_lam) + elif distiller: + train_loss = distiller(train_data, labels) + else: + train_loss = criterion(outputs, labels) + + #Calculate batch accuracy and accumulate in epoch accuracy + epoch_loss += train_loss / train_loader_len + output_labels = outputs.argmax(dim=1) + train_acc = (output_labels == labels).float().mean() + epoch_accuracy += train_acc / train_loader_len + + if sam: + train_loss.backward() #Gradient of loss + optimizer.first_step(zero_grad=True) #Perturb weights + outputs = model(train_data) #Outputs based on perturbed weights + if mixup_criterion: + perturbed_loss = mixup_criterion(criterion, outputs, labels, labels2, mixup_lam) + else: + perturbed_loss = criterion(outputs, labels) #Loss with perturbed weights + perturbed_loss.backward()#Gradient of perturbed loss + optimizer.second_step(zero_grad=True) #Unperturb weights and updated weights based on perturbed losses + optimizer.zero_grad() #Set gradients of optimized tensors to zero to prevent gradient accumulation + iteration += 1 + progress_bar.update(1) + + else: + # is_second_order attribute is added by timm on one optimizer + # (adahessian) + loss_scaler.scale(train_loss).backward( + create_graph=( + hasattr(optimizer, "is_second_order") + and optimizer.is_second_order + ) + ) + if optimizer_args.clip_grad is not None: + # unscale the gradients of optimizer's params in-place + loss_scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_( + vit.parameters(), optimizer_args.clip_grad + ) + + n_accum += 1 + + if n_accum == n_batch_accum: + n_accum = 0 + loss_scaler.step(optimizer) + loss_scaler.update() + + iteration += 1 + progress_bar.update(1) + torch.backends.cudnn.enabled = True # more consistent performance at cost of some nondeterminism torch.backends.cudnn.benchmark = True @@ -73,6 +155,12 @@ def train_deit(rank, num_gpus, config): global_batch_size = train_config["global_batch_size"] device = "cuda" if torch.cuda.is_available() else "cpu" pretrained_backbone = train_config["pretrained_backbone"] + sam = train_config["sam"] + rho = train_config["sam_rho"] + lr = train_config["lr"] + + mixup = data_config["mixup"] + mixup_alpha = data_config["mixup_alpha"] seed_everything(seed) @@ -160,7 +248,11 @@ def train_deit(rank, num_gpus, config): if distiller is not None: optimizer = create_optimizer(optimizer_args, distiller) else: - optimizer = create_optimizer(optimizer_args, vit) + if sam: + base_optimizer = AdamW + optimizer = SAM(vit.parameters(), base_optimizer, rho=rho, lr=lr) + else: + optimizer = create_optimizer(optimizer_args, vit) lr_scheduler, _ = create_scheduler(optimizer_args, optimizer) loss_scaler = torch.cuda.amp.GradScaler() # loss criterion used only when model trained without distillation and @@ -186,6 +278,8 @@ def train_deit(rank, num_gpus, config): n_accum = 0 epoch_last_val_loss = 0 epoch_last_val_accuracy = 0 + progress_bar = tqdm(range((epochs-epoch_offset)*len(train_loader))) + train_loader_len = len(train_loader) # Train loop for epoch in range(epoch_offset, epochs): epoch_loss = 0 @@ -215,7 +309,7 @@ def train_deit(rank, num_gpus, config): ) if ( - iteration % iters_per_val == 0 + iteration % len(train_loader) == 0 and n_accum == 0 and rank == 0 ): @@ -240,47 +334,23 @@ def train_deit(rank, num_gpus, config): train_imgs = train_imgs.to(device) train_labels = train_labels.to(device) - outputs = vit(train_imgs) - # calculate batch loss and accumulate in epoch loss - if distiller is not None: - train_loss = distiller(train_imgs, train_labels) + if mixup: + mixed_inputs, targets_a, targets_b, lam = mixup_data(train_imgs, + train_labels, alpha=mixup_alpha) + optimize(model=vit, + criterion=criterion, + train_data=mixed_inputs, + labels=targets_a, + sam=sam, + mixup_criterion=mixup_criterion, + labels2 = targets_b, + mixup_lam=lam) else: - train_loss = criterion(outputs, train_labels) - epoch_loss += train_loss / len(train_loader) - # calculate batch accuracy and accumulate in epoch accuracy - output_labels = outputs.argmax(dim=1) - train_acc = (output_labels == train_labels).float().mean() - epoch_accuracy += train_acc / len(train_loader) - - # is_second_order attribute is added by timm on one optimizer - # (adahessian) - loss_scaler.scale(train_loss).backward( - create_graph=( - hasattr(optimizer, "is_second_order") - and optimizer.is_second_order - ) - ) - if optimizer_args.clip_grad is not None: - # unscale the gradients of optimizer's params in-place - loss_scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_( - vit.parameters(), optimizer_args.clip_grad - ) - - n_accum += 1 - - if n_accum == n_batch_accum: - n_accum = 0 - loss_scaler.step(optimizer) - loss_scaler.update() - - iteration += 1 - - if rank == 0: - print( - f"Iteration {iteration}:\tloss={train_loss:.4f}" - f"\tacc={train_acc:.4f}" - ) + optimize(model=vit, + criterion=criterion, + train_data=train_imgs, + labels=train_labels, + sam=sam) lr_scheduler.step(epoch) diff --git a/classification/train_configs/vit_tiny.json b/classification/train_configs/vit_tiny.json index 3655800..d43abf7 100644 --- a/classification/train_configs/vit_tiny.json +++ b/classification/train_configs/vit_tiny.json @@ -28,7 +28,9 @@ "patience_epochs": 5, "decay_rate": 0.1, "distributed": true, - "pretrained_backbone": "vit_tiny_patch16_224" + "pretrained_backbone": "vit_tiny_patch16_224", + "sam": false, + "sam_rho": 0.05 }, "data_config_path": "data_configs/data_config.json",