From a0a89347a3ea7abb0f2a8c44665eee77f7c17621 Mon Sep 17 00:00:00 2001
From: Doctor G
Date: Fri, 27 Feb 2026 01:26:40 +0000
Subject: [PATCH] Add optional LR scheduler support via train_params

Allow train_params['lr_scheduler'] to accept a dict specifying a
scheduler type and its parameters. Supports reduce_on_plateau, step,
and cosine schedulers. Defaults to None (no scheduler), preserving
existing behavior.
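
Example configuration (illustrative values; "lr_scheduler" is the only
new key, the others are already read by configure_optimizers):

    train_params = {
        "optimizer": "adam",
        "lr": 1e-4,
        "max_epochs": 10,
        # New: cosine annealing; T_max falls back to max_epochs
        "lr_scheduler": {"type": "cosine"},
    }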
" + "Supported: reduce_on_plateau, step, cosine" + ) + + return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler}} + def count_params(self) -> int: """ Number of gradient enabled parameters in the model