diff --git a/README.md b/README.md
index 88a5572..ff712e2 100644
--- a/README.md
+++ b/README.md
@@ -35,7 +35,6 @@
## 📖 Table of Contents
- [Groups](#-Groups)
-- [Introduction](#-introduction)
- [News](#-news)
- [Installation](#%EF%B8%8F-installation)
- [Quick Start](#-quick-Start)
@@ -51,8 +50,6 @@ You can contact us and communicate with us by adding our group:
|:-------------------------:|
|
|
-## 📝 Introduction
-
## 🎉 News
- 🎉 2026.03.30: MCore-Bridge is released! Providing Megatron-Core model definitions for state-of-the-art large models and making Megatron training as simple as Transformers.
@@ -80,6 +77,8 @@ uv pip install -e . --torch-backend=auto
## 🚀 Quick Start
+For instructions on using MCore-Bridge for training, refer to the [ms-swift project](https://swift.readthedocs.io/en/latest/Megatron-SWIFT/Mcore-Bridge.html). Here we introduce how to use MCore-Bridge programmatically.
+
You need to create the following file (test.py), then run `CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 test.py`. Below is sample code demonstrating how to use Mcore-Bridge for model creation, weight loading, export, and saving.
The saved model can be used for inference by referring to the [example code in the model card](https://modelscope.cn/models/Qwen/Qwen3.5-35B-A3B).
diff --git a/README_zh.md b/README_zh.md
index 90c9b17..c502ad4 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -35,7 +35,6 @@
## 📖 目录
- [用户群](#-用户群)
-- [简介](#-简介)
- [新闻](#-新闻)
- [安装](#%EF%B8%8F-安装)
- [快速开始](#-快速开始)
@@ -50,8 +49,6 @@
|:-------------------------:|
|
|
-## 📝 简介
-
## 🎉 新闻
- 🎉 2026.03.30: MCore-Bridge 正式发布!为最先进的大模型提供 Megatron-Core 模型定义,让 Megatron 训练像 Transformers 一样简单。
@@ -79,6 +76,8 @@ uv pip install -e . --torch-backend=auto
## 🚀 快速开始
+如何使用MCore-Bridge进行训练可以参考[ms-swift项目](https://swift.readthedocs.io/zh-cn/latest/Megatron-SWIFT/Mcore-Bridge.html)。这里介绍如何以编程方式使用MCore-Bridge。
+
你需要创建以下文件(test.py),然后运行`CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 test.py`。以下为使用Mcore-Bridge进行创建模型、权重加载、导出、保存的示例代码。
保存的模型,可以参考[模型卡片的示例代码](https://modelscope.cn/models/Qwen/Qwen3.5-35B-A3B)进行推理。
diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py
index 48afc46..ca21567 100644
--- a/src/mcore_bridge/config/model_config.py
+++ b/src/mcore_bridge/config/model_config.py
@@ -1,4 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
+import copy
import os
import re
import torch.nn.functional as F
@@ -346,3 +347,14 @@ def _check_npu(self):
f'expert_model_parallel_size={expert_model_parallel_size}. '
f'Please set expert_model_parallel_size (EP) to {required_ep} '
f'(num_experts / {MAX_NPU_EXPERTS_PER_EP}) or higher.')
+
+ def __deepcopy__(self, memo):
+ cls = self.__class__
+ new_obj = cls.__new__(cls)
+ memo[id(self)] = new_obj
+ for k, v in self.__dict__.items():
+ if k == 'bridge':
+ setattr(new_obj, k, v)
+ else:
+ setattr(new_obj, k, copy.deepcopy(v, memo))
+ return new_obj
diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 29db5d0..6df275c 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -158,6 +158,7 @@ def _apply_rotary_pos_emb_bshd(
rotary_interleaved: bool = False,
multi_latent_attention: bool = False, # not use
mscale: float = 1.0,
+ **kwargs,
) -> torch.Tensor:
"""Apply rotary positional embedding to input tensor T.
@@ -390,6 +391,8 @@ def _postprocess(
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
+ if self.config.is_multimodal and self.config.context_parallel_size > 1:
+ input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
if self.mtp_process:
hidden_states = self.mtp(
@@ -406,55 +409,52 @@ def _postprocess(
embedding=self.embedding,
**(extra_block_kwargs or {}),
)
+ mtp_labels = labels.clone()
hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
hidden_states = hidden_states_list[0]
-
- if labels is not None:
- mtp_labels = labels.clone()
- if loss_mask is None:
- # if loss_mask is not provided, use all ones as loss_mask
- if packed_seq_params is None:
- loss_mask = torch.ones_like(mtp_labels)
- else:
- loss_mask = mtp_labels.new_ones((1, packed_seq_params.cu_seqlens_q[-1]))
- cu_seqlens = packed_seq_params.cu_seqlens_q if packed_seq_params is not None else None
- for mtp_layer_number in range(self.config.mtp_num_layers):
- # output
- mtp_logits, _ = self.output_layer(
- hidden_states_list[mtp_layer_number + 1],
- weight=output_weight,
- runtime_gather_output=runtime_gather_output,
+ if loss_mask is None:
+ # if loss_mask is not provided, use all ones as loss_mask
+ loss_mask = torch.ones_like(mtp_labels)
+ for mtp_layer_number in range(self.config.mtp_num_layers):
+ # output
+ mtp_logits, _ = self.output_layer(
+ hidden_states_list[mtp_layer_number + 1],
+ weight=output_weight,
+ runtime_gather_output=runtime_gather_output,
+ )
+ # Calc loss for the current Multi-Token Prediction (MTP) layers.
+ mtp_labels, _ = roll_tensor(
+ mtp_labels,
+ shifts=-1,
+ dims=-1,
+ cp_group=self.cp_group,
+ packed_seq_params=packed_seq_params,
+ )
+ loss_mask, _ = roll_tensor(
+ loss_mask,
+ shifts=-1,
+ dims=-1,
+ cp_group=self.cp_group,
+ packed_seq_params=packed_seq_params,
+ )
+ mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
+ loss_mask_ = (loss_mask & (mtp_labels != -100))
+ num_tokens = loss_mask_.sum()
+ mtp_loss = loss_mask_ * mtp_loss
+ if self.training:
+ mtp_loss_for_log = (
+ torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0))
+ MTPLossLoggingHelper.save_loss_to_tracker(
+ mtp_loss_for_log,
+ mtp_layer_number,
+ self.config.mtp_num_layers,
+ avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
- # Calc loss for the current Multi-Token Prediction (MTP) layers.
- mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
- if cu_seqlens is None:
- loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group)
- loss_mask_ = loss_mask
- else:
- loss_mask[:, cu_seqlens[:-1]] = 0
- loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1)
- if self.config.context_parallel_size > 1:
- loss_mask_ = split_cp_inputs(loss_mask, cu_seqlens, dim=1)
- else:
- loss_mask_ = loss_mask.clone()
- mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
- loss_mask_ = loss_mask_ & (mtp_labels != -100)
- mtp_loss = loss_mask_ * mtp_loss
- num_tokens = loss_mask_.sum()
- if self.training:
- mtp_loss_for_log = (
- torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0))
- MTPLossLoggingHelper.save_loss_to_tracker(
- mtp_loss_for_log,
- mtp_layer_number,
- self.config.mtp_num_layers,
- avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
- )
- mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
- if self.config.calculate_per_token_loss:
- hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
- else:
- hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
+ mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
+ if self.config.calculate_per_token_loss:
+ hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
+ else:
+ hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
sequence_parallel_override = False
if in_inference_mode and inference_context.materialize_only_last_token_logits:
if inference_context.is_static_batching():
diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index 2de8e51..392ff34 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -608,6 +608,7 @@ def _apply_rotary_pos_emb_thd(
multi_latent_attention: bool = False,
mscale: float = 1.0,
cp_group: torch.distributed.ProcessGroup = None,
+ **kwargs,
) -> torch.Tensor:
"""A baseline implementation of applying RoPE for `thd` format.
@@ -629,7 +630,8 @@ def _apply_rotary_pos_emb_thd(
use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item()
if not use_batched_rope:
logger.warning_once('Using non-batched RoPE, which may affect performance.')
- kwargs = {'cp_group': cp_group} if mcore_013 else {}
+ if mcore_013:
+ kwargs['cp_group'] = cp_group
return _origin_apply_rotary_pos_emb_thd(
t,
cu_seqlens,
@@ -646,6 +648,7 @@ def _apply_rotary_pos_emb_thd(
rotary_interleaved=rotary_interleaved,
multi_latent_attention=multi_latent_attention,
mscale=mscale,
+ **kwargs,
).squeeze(1)
rope_utils._apply_rotary_pos_emb_thd = _apply_rotary_pos_emb_thd
diff --git a/src/mcore_bridge/version.py b/src/mcore_bridge/version.py
index 3594f30..f4f1f4b 100644
--- a/src/mcore_bridge/version.py
+++ b/src/mcore_bridge/version.py
@@ -1,5 +1,5 @@
# Make sure to modify __release_datetime__ to release time when making official release.
-__version__ = '1.0.1.dev0'
+__version__ = '1.1.0.dev0'
# default release datetime for branches under active development is set
# to be a time far-far-away-into-the-future
__release_datetime__ = '2099-12-31 23:59:59'