Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__pycache__/

NWPU-RESISC45/
*/NWPU-RESISC45/
checkpoints_*/
*/venv
67 changes: 67 additions & 0 deletions classification/a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
def optimize(model,
             criterion,
             train_data,
             labels,
             sam=True,
             mixup_criterion=None,
             labels2=None,
             mixup_lam=None,
             distiller=None):
    """Run one training step for a single batch.

    Computes the loss (plain, mixup, or distillation), accumulates epoch
    statistics, and applies either a two-step SAM update or a loss-scaled
    update with optional gradient clipping and gradient accumulation.

    Parameters
    ----------
    model:
        Network being trained; called as ``model(train_data)``.
    criterion:
        Base loss function, ``criterion(outputs, labels)``.
    train_data:
        Input batch tensor.
    labels:
        Ground-truth labels for the batch.
    sam: bool
        If True, use the SAM optimizer's first/second perturbation steps.
    mixup_criterion: callable or None
        Mixup loss wrapper; requires ``labels2`` and ``mixup_lam``.
    labels2:
        Second label set produced by mixup (same length as ``labels``).
    mixup_lam:
        Mixup interpolation coefficient.
    distiller: callable or None
        Optional distillation loss, called as ``distiller(train_data, labels)``.

    NOTE(review): this function rebinds several names it does not define
    (``epoch_loss``, ``epoch_accuracy``, ``n_accum``, ``iteration``) and reads
    others (``train_loader``, ``optimizer``, ``loss_scaler``, ``optimizer_args``,
    ``n_batch_accum``, ``progress_bar``). Augmented assignment to an outer-scope
    name raises ``UnboundLocalError`` unless it is declared ``global``/``nonlocal``
    or passed in -- confirm against the intended call site.
    """
    if mixup_criterion is not None:
        # Check mixup_lam first: len(None) would raise an opaque TypeError.
        assert mixup_lam is not None
        assert len(labels2) == len(labels)

    outputs = model(train_data)

    if mixup_criterion:
        train_loss = mixup_criterion(criterion, outputs, labels, labels2, mixup_lam)
    elif distiller:
        train_loss = distiller(train_data, labels)
    else:
        train_loss = criterion(outputs, labels)

    # Calculate batch loss/accuracy and accumulate into epoch averages.
    epoch_loss += train_loss / len(train_loader)
    output_labels = outputs.argmax(dim=1)
    # BUGFIX: was `train_labels`, which is undefined here -- use the
    # `labels` parameter.
    train_acc = (output_labels == labels).float().mean()
    epoch_accuracy += train_acc / len(train_loader)

    if sam:
        train_loss.backward()                  # Gradient of loss
        optimizer.first_step(zero_grad=True)   # Perturb weights
        # BUGFIX: was `vit(train_data)` -- use the `model` parameter.
        outputs = model(train_data)            # Outputs based on perturbed weights
        if mixup_criterion:
            perturbed_loss = mixup_criterion(criterion, outputs, labels, labels2, mixup_lam)
        else:
            # BUGFIX: was `train_labels` (undefined) -- use `labels`.
            perturbed_loss = criterion(outputs, labels)  # Loss with perturbed weights
        perturbed_loss.backward()              # Gradient of perturbed loss
        optimizer.second_step(zero_grad=True)  # Unperturb weights, update from perturbed grads
        optimizer.zero_grad()                  # Prevent gradient accumulation across batches
        iteration += 1
        progress_bar.update(1)

    else:
        # is_second_order attribute is added by timm on one optimizer
        # (adahessian)
        loss_scaler.scale(train_loss).backward(
            create_graph=(
                hasattr(optimizer, "is_second_order")
                and optimizer.is_second_order
            )
        )
        if optimizer_args.clip_grad is not None:
            # unscale the gradients of optimizer's params in-place
            loss_scaler.unscale_(optimizer)
            # BUGFIX: was `vit.parameters()` (undefined) -- clip the
            # `model` parameter's gradients.
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), optimizer_args.clip_grad
            )

        n_accum += 1

        # Step only every `n_batch_accum` batches (gradient accumulation).
        if n_accum == n_batch_accum:
            n_accum = 0
            loss_scaler.step(optimizer)
            loss_scaler.update()

        iteration += 1
        progress_bar.update(1)
4 changes: 4 additions & 0 deletions classification/data_configs/data_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
"test_files": "test_imagepaths.txt",
"label_map": "label_map.json",
"number_of_classes": 45,
"mixup": false,
"mixup_alpha": 0.7,
"transform_ops_train": {
"RandomResizedCrop": 224,
"RandomHorizontalFlip": null,
"RandomVerticalFlip": null,
"RandAugment": null,
"Normalize": {
"Mean": [0.485, 0.456, 0.406],
"Std": [0.229, 0.224, 0.225]
Expand Down
23 changes: 23 additions & 0 deletions classification/mixup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np
import torch

def mixup_data(X, y, alpha=1):
    """Implement mixup.

    Returns mixed inputs, pairs of targets, and lambda,
    which represents fraction of mixup.

    https://arxiv.org/abs/1710.09412

    Parameters
    ----------
    X: torch.Tensor
        Input batch; first dimension is the batch dimension.
    y: torch.Tensor
        Targets aligned with ``X`` along the batch dimension.
    alpha: float
        Beta-distribution parameter; ``alpha <= 0`` disables mixing (lam = 1).

    Returns
    -------
    (mixed_X, y_1, y_2, lam):
        ``mixed_X = lam*X + (1-lam)*X[perm]``, the original targets,
        the permuted targets, and the mixing coefficient.
    """
    # lam ~ Beta(alpha, alpha); degenerate to 1 (no mixing) when alpha <= 0.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = X.size(0)
    # BUGFIX: build the permutation on X's own device. The previous code used
    # CUDA whenever it was available, which crashes when a CPU tensor is
    # indexed with a CUDA index on a CUDA-capable machine.
    index = torch.randperm(batch_size, device=X.device)

    mixed_X = lam*X + (1-lam)*X[index, :]
    y_1, y_2 = y, y[index]

    return mixed_X, y_1, y_2, lam

def mixup_criterion(criterion, pred, y_1, y_2, lam):
    """Return the mixup loss: a convex combination of the losses against
    the two mixed target sets, weighted by the mixing coefficient ``lam``."""
    loss_first = criterion(pred, y_1)
    loss_second = criterion(pred, y_2)
    return lam*loss_first + (1-lam)*loss_second
4 changes: 4 additions & 0 deletions classification/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@ psutil
numpy
scikit-learn
tqdm
wandb
pynvml
ninja
transformers
116 changes: 116 additions & 0 deletions classification/sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import torch

class SAM(torch.optim.Optimizer):
    """Sharpness aware minimization.

    Wraps a base optimizer and performs the two-step SAM update: first
    perturb the weights toward the local loss maximum (``first_step``),
    then restore them and update using the gradients computed at the
    perturbed point (``second_step``).

    Parameters
    ----------
    params:
        model parameters
    base_optimizer: torch.optim.Optimizer
        e.g. ADAMW or SGD
    rho: float
        hyperparameter for SAM - perturbation strength (default 0.05)
    gsam_alpha: float
        hyperparameter for GSAM (default 0.05)
    GSAM: bool
        whether or not to use Surrogate Gap Sharpness Aware Minimization
    adaptive: bool
        whether or not to use Adaptive SAM

    Public Methods
    --------------
    first_step(self, zero_grad):
        Perturb weights.
    unperturb(self, zero_grad):
        Unperturb weights.
    second_step(self, zero_grad):
        Unperturb and update weights.
    load_state_dict(self, state_dict):
        Copies parameters and buffers from state_dict
        into this module and its descendants.

    """
    def __init__(self, params, base_optimizer, rho=0.05, gsam_alpha=0.05, GSAM=False, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.rho = rho
        self.alpha = gsam_alpha
        self.GSAM = GSAM
        # Instantiate the wrapped optimizer over the same parameter groups,
        # then alias its param_groups so both optimizers share state.
        # NOTE(review): `base_optimizer` is the optimizer *class*, not an
        # instance -- it is called here with the param groups.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb weights.

        Moves each parameter by ``rho * grad / ||grad||`` (elementwise
        scaled by ``p**2`` for Adaptive SAM), saving the originals so
        they can be restored later.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups: #Iterate over parameters/weights
            scale = self.rho / (grad_norm + 1e-12) #Perturbation factor

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone() #Save old parameter to unperturb later
                if self.GSAM: self.state[p]["old_p_grad"] = p.grad.data.clone() #Save old gradient (if GSAM)
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) #Calculate perturbation
                p.add_(e_w) #Climb to the local maximum "w + e(w)" - perturb
        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def unperturb(self, zero_grad=False):
        """Return to old parameters - remove perturbation."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"] #Get back to "w" from "w + e(w)"
        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Unperturb weights and then update.

        If GSAM, decompose gradients before unperturbing weights."""
        if self.GSAM:
            self._decompose_grad()
        self.unperturb()

        self.base_optimizer.step() # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    def _grad_norm(self):
        """Return the global L2 norm over all parameter gradients
        (elementwise weighted by |p| for Adaptive SAM)."""
        shared_device = self.param_groups[0]["params"][0].device #Put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

    def load_state_dict(self, state_dict):
        # Re-alias param_groups after loading so the wrapped optimizer
        # keeps sharing the (freshly loaded) groups.
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

    def _decompose_grad(self):
        """Decompose gradient of unperturbed loss into directions parallel and
        perpendicular to the gradients of the perturbed losses, for GSAM.
        Subtract perpendicular component from perturbed loss gradients.
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None: continue
                # NOTE(review): this raises KeyError (not None) if first_step
                # ran with GSAM disabled; the `is None` guard below never
                # fires because first_step stores a clone -- confirm intent.
                old_grad = self.state[p]["old_p_grad"]
                if old_grad is None: continue
                #Find factor of component parallel to perturbed loss
                #Take dot product between two vectors.
                a = torch.dot(p.grad.data.view(-1), old_grad.view(-1))/torch.norm(p.grad.data)**2
                perp = old_grad - a*p.grad.data #Component perpendicular to perturbed loss = vector - parallel component
                # NOTE(review): if `perp` is the zero vector this divides by
                # zero and produces NaNs -- consider an epsilon guard.
                norm_perp = perp / torch.norm(perp) #Normalise perpendicular component
                #Subtract perp component from perturbed loss gradients, with factor alpha.
                p.grad.data.sub_(self.alpha * norm_perp)
Loading