Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@

## 📖 Table of Contents
- [Groups](#-Groups)
- [Introduction](#-introduction)
- [News](#-news)
- [Installation](#%EF%B8%8F-installation)
- [Quick Start](#-quick-Start)
Expand All @@ -51,8 +50,6 @@ You can contact us and communicate with us by adding our group:
|:-------------------------:|
| <img src="https://raw.githubusercontent.com/modelscope/ms-swift/main/docs/resources/wechat/megatron.png" width="200" height="200"> |

## 📝 Introduction

## 🎉 News
- 🎉 2026.03.30: MCore-Bridge is released! Providing Megatron-Core model definitions for state-of-the-art large models and making Megatron training as simple as Transformers.

Expand Down Expand Up @@ -80,6 +77,8 @@ uv pip install -e . --torch-backend=auto

## 🚀 Quick Start

For details on how to use MCore-Bridge for training, see the [ms-swift project](https://swift.readthedocs.io/en/latest/Megatron-SWIFT/Mcore-Bridge.html). Here we introduce how to use MCore-Bridge programmatically.

You need to create the following file (test.py), then run `CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 test.py`. Below is sample code demonstrating how to use MCore-Bridge for model creation, weight loading, export, and saving.

The saved model can be used for inference by referring to the [example code in the model card](https://modelscope.cn/models/Qwen/Qwen3.5-35B-A3B).
Expand Down
5 changes: 2 additions & 3 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@

## 📖 目录
- [用户群](#-用户群)
- [简介](#-简介)
- [新闻](#-新闻)
- [安装](#%EF%B8%8F-安装)
- [快速开始](#-快速开始)
Expand All @@ -50,8 +49,6 @@
|:-------------------------:|
| <img src="https://raw.githubusercontent.com/modelscope/ms-swift/main/docs/resources/wechat/megatron.png" width="200" height="200"> |

## 📝 简介

## 🎉 新闻
- 🎉 2026.03.30: MCore-Bridge 正式发布!为最先进的大模型提供 Megatron-Core 模型定义,让 Megatron 训练像 Transformers 一样简单。

Expand Down Expand Up @@ -79,6 +76,8 @@ uv pip install -e . --torch-backend=auto

## 🚀 快速开始

如何使用MCore-Bridge进行训练可以参考[ms-swift项目](https://swift.readthedocs.io/zh-cn/latest/Megatron-SWIFT/Mcore-Bridge.html)。这里介绍如何以编程方式使用MCore-Bridge。

你需要创建以下文件(test.py),然后运行`CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 test.py`。以下为使用MCore-Bridge进行创建模型、权重加载、导出、保存的示例代码。

保存的模型,可以参考[模型卡片的示例代码](https://modelscope.cn/models/Qwen/Qwen3.5-35B-A3B)进行推理。
Expand Down
12 changes: 12 additions & 0 deletions src/mcore_bridge/config/model_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import copy
import os
import re
import torch.nn.functional as F
Expand Down Expand Up @@ -346,3 +347,14 @@ def _check_npu(self):
f'expert_model_parallel_size={expert_model_parallel_size}. '
f'Please set expert_model_parallel_size (EP) to {required_ep} '
f'(num_experts / {MAX_NPU_EXPERTS_PER_EP}) or higher.')

def __deepcopy__(self, memo):
    """Deep-copy this config while sharing the ``bridge`` attribute by reference.

    All entries of ``__dict__`` are recursively copied except ``bridge``,
    which is deliberately carried over as the same object.
    NOTE(review): presumably ``bridge`` holds heavyweight or non-copyable
    state that must stay shared across copies — confirm with the owner.

    Args:
        memo: The standard ``copy.deepcopy`` memo dict, used to resolve
            already-copied objects and cyclic references.

    Returns:
        A new instance of the same class with ``bridge`` shared and every
        other attribute deep-copied.
    """
    klass = type(self)
    clone = klass.__new__(klass)
    # Register the clone before copying attributes so any cyclic or
    # self-referential structure resolves back to the clone, not self.
    memo[id(self)] = clone
    for attr_name, attr_value in self.__dict__.items():
        copied_value = attr_value if attr_name == 'bridge' else copy.deepcopy(attr_value, memo)
        setattr(clone, attr_name, copied_value)
    return clone
92 changes: 46 additions & 46 deletions src/mcore_bridge/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def _apply_rotary_pos_emb_bshd(
rotary_interleaved: bool = False,
multi_latent_attention: bool = False, # not use
mscale: float = 1.0,
**kwargs,
) -> torch.Tensor:
"""Apply rotary positional embedding to input tensor T.

Expand Down Expand Up @@ -390,6 +391,8 @@ def _postprocess(
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
if self.config.is_multimodal and self.config.context_parallel_size > 1:
input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)

if self.mtp_process:
hidden_states = self.mtp(
Expand All @@ -406,55 +409,52 @@ def _postprocess(
embedding=self.embedding,
**(extra_block_kwargs or {}),
)
mtp_labels = labels.clone()
hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
hidden_states = hidden_states_list[0]

if labels is not None:
mtp_labels = labels.clone()
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
if packed_seq_params is None:
loss_mask = torch.ones_like(mtp_labels)
else:
loss_mask = mtp_labels.new_ones((1, packed_seq_params.cu_seqlens_q[-1]))
cu_seqlens = packed_seq_params.cu_seqlens_q if packed_seq_params is not None else None
for mtp_layer_number in range(self.config.mtp_num_layers):
# output
mtp_logits, _ = self.output_layer(
hidden_states_list[mtp_layer_number + 1],
weight=output_weight,
runtime_gather_output=runtime_gather_output,
if loss_mask is None:
# if loss_mask is not provided, use all ones as loss_mask
loss_mask = torch.ones_like(mtp_labels)
for mtp_layer_number in range(self.config.mtp_num_layers):
# output
mtp_logits, _ = self.output_layer(
hidden_states_list[mtp_layer_number + 1],
weight=output_weight,
runtime_gather_output=runtime_gather_output,
)
# Calc loss for the current Multi-Token Prediction (MTP) layers.
mtp_labels, _ = roll_tensor(
mtp_labels,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
loss_mask, _ = roll_tensor(
loss_mask,
shifts=-1,
dims=-1,
cp_group=self.cp_group,
packed_seq_params=packed_seq_params,
)
mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
loss_mask_ = (loss_mask & (mtp_labels != -100))
num_tokens = loss_mask_.sum()
mtp_loss = loss_mask_ * mtp_loss
if self.training:
mtp_loss_for_log = (
torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0))
MTPLossLoggingHelper.save_loss_to_tracker(
mtp_loss_for_log,
mtp_layer_number,
self.config.mtp_num_layers,
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
# Calc loss for the current Multi-Token Prediction (MTP) layers.
mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
if cu_seqlens is None:
loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group)
loss_mask_ = loss_mask
else:
loss_mask[:, cu_seqlens[:-1]] = 0
loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1)
if self.config.context_parallel_size > 1:
loss_mask_ = split_cp_inputs(loss_mask, cu_seqlens, dim=1)
else:
loss_mask_ = loss_mask.clone()
mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
loss_mask_ = loss_mask_ & (mtp_labels != -100)
mtp_loss = loss_mask_ * mtp_loss
num_tokens = loss_mask_.sum()
if self.training:
mtp_loss_for_log = (
torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0))
MTPLossLoggingHelper.save_loss_to_tracker(
mtp_loss_for_log,
mtp_layer_number,
self.config.mtp_num_layers,
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
if self.config.calculate_per_token_loss:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
else:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
if self.config.calculate_per_token_loss:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
else:
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
sequence_parallel_override = False
if in_inference_mode and inference_context.materialize_only_last_token_logits:
if inference_context.is_static_batching():
Expand Down
5 changes: 4 additions & 1 deletion src/mcore_bridge/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,7 @@ def _apply_rotary_pos_emb_thd(
multi_latent_attention: bool = False,
mscale: float = 1.0,
cp_group: torch.distributed.ProcessGroup = None,
**kwargs,
) -> torch.Tensor:
"""A baseline implementation of applying RoPE for `thd` format.

Expand All @@ -629,7 +630,8 @@ def _apply_rotary_pos_emb_thd(
use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item()
if not use_batched_rope:
logger.warning_once('Using non-batched RoPE, which may affect performance.')
kwargs = {'cp_group': cp_group} if mcore_013 else {}
if mcore_013:
kwargs['cp_group'] = cp_group
return _origin_apply_rotary_pos_emb_thd(
t,
cu_seqlens,
Expand All @@ -646,6 +648,7 @@ def _apply_rotary_pos_emb_thd(
rotary_interleaved=rotary_interleaved,
multi_latent_attention=multi_latent_attention,
mscale=mscale,
**kwargs,
).squeeze(1)

rope_utils._apply_rotary_pos_emb_thd = _apply_rotary_pos_emb_thd
Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Make sure to modify __release_datetime__ to release time when making official release.
__version__ = '1.0.1.dev0'
__version__ = '1.1.0.dev0'
# default release datetime for branches under active development is set
# to be a time far-far-away-into-the-future
__release_datetime__ = '2099-12-31 23:59:59'
Loading