Conversation
…base class - Change Step() to virtual with default implementation - Add pure virtual ComputeLR() for subclasses to implement. - Adapt test helpers (IdentityScheduler, LinearDecayScheduler) to implement ComputeLR() instead of Step(). - All existing tests pass without behavioral changes. BREAKING CHANGE: Subclasses must implement ComputeLR() instead of Step().
…t and update all tests to use Create<T>() factory method.
…entialLR - enhance LRScheduler with chained and closed-form learning rate methods - adapt methods (Step, InitialStep, GetClosedFormLR, GetChainedFormLR) to match PyTorch's design - add tests for consistency - refactor LinearLR: add end_factor and rename this class - add SequentialLR InitialStep and UndoChildInitialSteps BREAKING CHANGE: Subclasses must implement GetClosedFormLR() instead of ComputeLR(). Use LinearLR instead of LinearwarmupLR.
- Add LRSchedulerConfig struct with parameters for all basic schedulers (constant, linear, step) - Add CreateLRScheduler() factory function - Support automatic warmup wrapping via SequentialLR when warmup_steps > 0 - Adapt test files
…tial, Chained, and Lambda)
…ommon total_iters
- Add gflags: --lr_scheduler, --warmup_steps, --step_size, --gamma, --start_factor, --end_factor, --lr_total_iters, --total_steps - Replace nullptr scheduler with factory-created scheduler - Move scheduler.Step() after optimizer.Step() in both DP and PP paths - Replace hardcoded FLAGS_learning_rate in log with scheduler->GetLR()
example/gpt2/main.cc
Outdated
```cpp
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
```
The scheduler has already been Stepped before this point, so here GetLR() semantically returns the LR that will be used on the *next* step; what we want to print is the LR actually used on the current step, so this logic needs to be fixed. The same applies to main.cc in the llama3 part.
example/llama3/main.cc
Outdated
```cpp
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
```
```cpp
std::vector<std::shared_ptr<Tensor>> params_;
float learning_rate_ = 0.0f;
float initial_learning_rate_ = 0.0f;
bool initial_lr_set_ = false;
```
This part is redundant. The optimizer only needs to store learning_rate_ for the current learning rate; there is no need for extra initial-LR state. Semantically, the initial learning rate can live solely in the LR scheduler (which you effectively already do — it is stored as the scheduler's base_lr).
This aligns with PyTorch's initialization behavior (source link). When PyTorch initializes a scheduler, it visits the parameter groups of the associated optimizer and calls setdefault to set initial_lr: for an optimizer being associated for the first time, the current learning rate is stored as initial_lr; for an optimizer that already has the value, the existing value is returned.
The only use case I can see is that, when multiple schedulers are attached to the same optimizer (e.g. ChainedScheduler or SequentialLR), this guarantees their base_lr_ all equal the optimizer's learning rate at the moment the first scheduler was initialized. I'm not aware of other scenarios, but I added the related state for consistency with PyTorch. If only ChainedScheduler/SequentialLR are involved, there are indeed alternatives — should I change it?
infini_train/include/lr_scheduler.h
Outdated
```cpp
std::shared_ptr<Optimizer> optimizer_;
int64_t last_step_;
float current_lr_;
```
current_lr_ also looks redundant: semantically, current_lr_ and optimizer_->GetLearningRate() should be equal at all times, yet in your design the two are stored separately and used interchangeably (on a full read, current_lr_ is effectively a copy of optimizer_->GetLearningRate()). Your handling is numerically correct today, but this design is likely to create ambiguity for whoever extends it later.
I suggest keeping a single source of truth for the "current learning rate": either track it entirely via optimizer_->GetLearningRate() and drop current_lr from the scheduler, or track it in the scheduler and set it back to the optimizer after each computation. I think the former is more appropriate.
Fixed. Since schedulers need to support resuming training, and some (e.g. SequentialLR or ChainedScheduler) have no closed-form computation — the LR cannot be recovered quickly from base_lr and last_epoch — the field is kept solely for learning-rate recovery and renamed to recover_lr to avoid confusion.
infini_train/src/lr_scheduler.cc
Outdated
```cpp
void LRScheduler::ApplyLR(float lr) {
    current_lr_ = lr;
    optimizer_->SetLearningRate(current_lr_);
```
Following up on the above: in your design there is on one hand the call optimizer_->SetLearningRate(current_lr_); and on the other current_lr_ = optimizer_->GetLearningRate();. It is easy to confuse which is the cause and which is the effect, so I suggest keeping the design semantics consistent.
infini_train/src/lr_scheduler.cc
Outdated
```cpp
    scheduler->Step();
}

current_lr_ = optimizer_->GetLearningRate();
```
Following up on the above: in your design there is on one hand the call optimizer_->SetLearningRate(current_lr_); and on the other current_lr_ = optimizer_->GetLearningRate();. It is easy to confuse which is the cause and which is the effect, so I suggest keeping the design semantics consistent.
```cpp
} else if (last_step_ < total_iters_) {
    return lr;
} else if (last_step_ == total_iters_) {
    return lr / factor_;
```
Some hyperparameter values are passed in by CLI users, so they need validity checks. Taking this spot as an example: factor should be in the (0, 1) range, otherwise illegal values such as a division by zero are possible. The torch implementation also checks this in the constructor, see: https://github.com/pytorch/pytorch/blob/08840d08a02eead8edf22406a53e5691c9a89c9a/torch/optim/lr_scheduler.py#L813
Also, from what I can see, StepLR doesn't check step_size > 0, and LinearLR doesn't check the two factors or total_iters. I suggest a pass over the whole file.
infini_train/include/lr_scheduler.h
Outdated
```cpp
void LoadState(const StateDict &state) override;

protected:
    float GetClosedFormLR() const override { return current_lr_; }
```
This is semantically off. I looked carefully at the torch implementation: if GetClosedFormLR maps to torch's _get_closed_form_lr interface, the intent is "a function that, given base_lr, last_step, and the other hyperparameters, computes the current lr from a formula". Although that is numerically equal to the current_lr you provide now, logically the code should not just return the cached current_lr_ — it should implement the formula.
Also, torch's _get_closed_form_lr is ultimately used by the step(int epoch) function: if a derived LRScheduler class implements _get_closed_form_lr, that class supports closed-form step skipping, and step(epoch) computes the current lr directly from the provided function. torch's SequentialLR derived class does not implement this function.
Given that your GetClosedFormLR is declared as a virtual function that every derived class must implement, I suggest adding a // FIXME comment here noting this — that returning current_lr is a temporary hack, not a real closed-form computation.
infini_train/include/lr_scheduler.h
Outdated
```cpp
};

} // namespace lr_schedulers
} // namespace infini_train
\ No newline at end of file
```
Formatting convention: there should be a newline at end of file; several other files have the same issue.
- it is now only used for learning-rate recovery when using LoadState
Pull request overview
This PR introduces a learning-rate scheduler system to infini_train, integrates it with optimizers (including distributed optimizer), and adds standalone C++ test executables plus example CLI wiring to exercise the new schedulers.
Changes:
- Add `LRScheduler` base + concrete schedulers (ConstantLR/StepLR/LinearLR/LambdaLR/SequentialLR/ChainedScheduler) and a `CreateLRScheduler` factory.
- Extend `Optimizer` with runtime-settable learning rate and initial learning rate tracking; propagate LR to `DistributedOptimizer`.
- Add scheduler coverage tests and wire scheduler flags into `example/gpt2` and `example/llama3`; register new test executables in CMake.
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| `infini_train/include/lr_scheduler.h` | Declares scheduler APIs, configs, and concrete scheduler types. |
| `infini_train/src/lr_scheduler.cc` | Implements scheduler logic, factory creation, state save/load, sequential/chained behavior. |
| `infini_train/include/optimizer.h` | Adds LR getters/setters + initial LR tracking to support schedulers. |
| `infini_train/src/optimizer.cc` | Implements optimizer LR plumbing and updates SGD/Adam to use base LR storage. |
| `infini_train/include/nn/parallel/ddp/distributed_optimizer.h` | Overrides LR get/set for distributed optimizer so schedulers affect the real base optimizer. |
| `infini_train/src/nn/parallel/ddp/distributed_optimizer.cc` | Implements LR propagation to/from the wrapped base optimizer. |
| `example/gpt2/main.cc` | Adds scheduler CLI flags and steps the scheduler during training. |
| `example/llama3/main.cc` | Adds scheduler CLI flags and steps the scheduler during training. |
| `test/lr_scheduler/test_helpers.h` | Shared minimal test helpers/macros for scheduler tests. |
| `test/lr_scheduler/test_*.cc` | Adds functional + state + validation tests for schedulers. |
| `CMakeLists.txt` | Adds new scheduler test executables to the build. |
```cpp
LRSchedulerConfig linear_config = {
    .type = "linear",
    .linear_start_factor = 1e-8f,
    .linear_end_factor = 1.0f,
    .linear_total_iters = 3,
};
auto linear = CreateLRScheduler(opt, linear_config);
LRSchedulerConfig constant_config = {
    .type = "constant",
    .constant_factor = 1.0f,
    .constant_total_iters = 100,
};
auto constant = CreateLRScheduler(opt, constant_config);
```
Several schedulers are created but never used (linear, constant), which is dead code and can trigger compiler warnings or confuse readers about what's being tested. Remove these unused variables (or explicitly cast to void if construction side-effects are what you want to test).
```cpp
auto linear = CreateLRScheduler(opt, {
    .type = "linear",
    .linear_start_factor = 1e-8f,
    .linear_end_factor = 1.0f,
    .linear_total_iters = 3,
});
auto step_lr = CreateLRScheduler(opt, {
    .type = "step",
    .step_size = 3,
    .step_gamma = 0.1f,
});
auto Lambda = CreateLRScheduler(opt, {
    .type = "lambda",
    .lambda_fn = [](int64_t step) { return 1.0f - 0.1f * step; },
});
```
linear and Lambda are created but never used in this test case, which is dead code and can lead to misleading tests / unused-variable warnings. Remove them or (void)-mark them if the intent is to validate construction only.
Suggested change:
```diff
-auto linear = CreateLRScheduler(opt, {
-    .type = "linear",
-    .linear_start_factor = 1e-8f,
-    .linear_end_factor = 1.0f,
-    .linear_total_iters = 3,
-});
-auto step_lr = CreateLRScheduler(opt, {
-    .type = "step",
-    .step_size = 3,
-    .step_gamma = 0.1f,
-});
-auto Lambda = CreateLRScheduler(opt, {
-    .type = "lambda",
-    .lambda_fn = [](int64_t step) { return 1.0f - 0.1f * step; },
-});
+auto step_lr = CreateLRScheduler(opt, {
+    .type = "step",
+    .step_size = 3,
+    .step_gamma = 0.1f,
+});
```
```cpp
void SequentialLR::Step() {
    ++last_step_;
    size_t idx = std::upper_bound(milestones_.begin(), milestones_.end(), last_step_) - milestones_.begin();
```
SequentialLR::Step() uses std::upper_bound, but this translation unit doesn't include <algorithm>. Relying on indirect includes is brittle and can fail to compile depending on standard library headers; add #include <algorithm> in this file (or in the header that declares this usage).
```cpp
float ConstantLR::GetChainedFormLR() const {
    const float lr = optimizer_->GetLearningRate();
    if (last_step_ == 0) {
        return lr * factor_;
    } else if (last_step_ < total_iters_) {
        return lr;
    } else if (last_step_ == total_iters_) {
        return lr / factor_;
    }
    return lr;
```
ConstantLR::GetChainedFormLR() divides by factor_ when last_step_ == total_iters_. Since the constructor allows factor_ == 0, this can cause a division-by-zero at runtime. Either disallow factor_ == 0 (CHECK_GT) or handle the factor_ == 0 transition without dividing.
```cpp
for (const auto &sub_config : config.sequential_configs) {
    auto sub_sched = CreateLRScheduler(opt, sub_config);
    if (sub_sched) {
        schedulers.push_back(sub_sched);
    }
}
```
CreateLRScheduler() filters out sub-schedulers when a sub-config has type == "none", but it keeps sequential_milestones unchanged. This can make milestones.size() != schedulers.size() - 1 and hard-fail in SequentialLR's constructor. Consider rejecting none inside sequential_configs, or keep placeholders / adjust milestones so the sizes stay consistent.
Suggested change:
```diff
-for (const auto &sub_config : config.sequential_configs) {
-    auto sub_sched = CreateLRScheduler(opt, sub_config);
-    if (sub_sched) {
-        schedulers.push_back(sub_sched);
-    }
-}
+for (const auto &sub_config : config.sequential_configs) {
+    if (sub_config.type == "none") {
+        LOG(FATAL) << "LR scheduler type \"none\" is not allowed inside sequential_configs.";
+    }
+    auto sub_sched = CreateLRScheduler(opt, sub_config);
+    if (sub_sched) {
+        schedulers.push_back(sub_sched);
+    }
+}
+if (!schedulers.empty() && milestones.size() != schedulers.size() - 1) {
+    LOG(FATAL) << "SequentialLR requires milestones.size() == schedulers.size() - 1, but got "
+               << milestones.size() << " milestones and " << schedulers.size() << " schedulers.";
+}
```
```cpp
auto warmup_scheduler = LRScheduler::Create<lr_schedulers::LinearLR>(optimizer,
                                                                     /*start_factor=*/config.warmup_start_factor,
                                                                     /*end_factor=*/config.warmup_end_factor,
                                                                     /*total_iters=*/config.warmup_steps);

auto main_scheduler = create_main(optimizer);

return LRScheduler::Create<lr_schedulers::SequentialLR>(
    optimizer, std::vector<std::shared_ptr<LRScheduler>>{warmup_scheduler, main_scheduler},
    std::vector<int64_t>{config.warmup_steps});
```
When warmup_steps > 0, main_scheduler can be nullptr (e.g., if config.type is unsupported/"none" or sub-configs filter out to empty), but the code unconditionally passes it into SequentialLR, which will CHECK-fail on a null child. Add an explicit check here (either return only the warmup scheduler, or fail fast with a clear error) before constructing the SequentialLR.
```cpp
StepLR::StepLR(std::shared_ptr<Optimizer> optimizer, int64_t step_size, float gamma, int64_t last_step)
    : LRScheduler(std::move(optimizer), last_step), step_size_(step_size), gamma_(gamma) {
    CHECK_GT(step_size_, 0) << "StepLR: step_size must be > 0.";
```
StepLR validates step_size_ > 0 but doesn't validate gamma_. Non-positive gamma_ can produce negative/NaN learning rates via pow() and via the multiplicative chained form. Consider CHECK_GT(gamma_, 0.0f) (and possibly a sensible upper bound if desired) to match typical scheduler expectations.
Suggested change:
```diff
 CHECK_GT(step_size_, 0) << "StepLR: step_size must be > 0.";
+CHECK_GT(gamma_, 0.0f) << "StepLR: gamma must be > 0.";
```
```cpp
    }
}

void ChainedScheduler::InitialStep() {}
```
ChainedScheduler::InitialStep() is empty, so a newly created ChainedScheduler keeps last_step_ == -1 while other schedulers start at 0 after Create(). This makes LastStep() inconsistent (and the saved last_step in State() will lag children by 1). Consider setting the initial step to 0 (without advancing children) to keep LastStep() semantics consistent across schedulers.
Suggested change:
```diff
-void ChainedScheduler::InitialStep() {}
+void ChainedScheduler::InitialStep() {
+    // Ensure consistent LastStep semantics with other schedulers:
+    // a newly created ChainedScheduler should start at step 0.
+    last_step_ = 0;
+}
```
No description provided.