Feature: make MORI framework agnostic and compatible with JAX-XLA#173

Open
TianDi101 wants to merge 2 commits intomainfrom
chao/mori_frmk_agnostic

Conversation

@TianDi101
Collaborator

No description provided.


Copilot AI left a comment


Pull request overview

This PR introduces a Torch-optional (“framework agnostic”) build by splitting Torch-dependent components (bootstrap + pybind entrypoints) from the core shmem/IO functionality, and gating Torch-only codepaths behind MORI_ENABLE_TORCH.

Changes:

  • Add a host-only shmem header (shmem_host_api.hpp) and wrap Torch-specific shmem init behind MORI_ENABLE_TORCH.
  • Split application and pybind targets into core vs Torch-specific libraries, updating CMake link dependencies accordingly.
  • Update Python import loader to prefer Torch-enabled pybinds when present, otherwise fall back to core-only pybinds.
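
The loader fallback described in the last bullet can be sketched roughly as follows. This is a minimal illustration, not the actual python/mori/cpp/__init__.py code; the module names in the commented usage are the ones this PR builds, but the helper itself is hypothetical:

```python
import importlib

def load_first_available(candidates):
    """Return the first importable module from candidates, else raise ImportError."""
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError:
            continue
    raise ImportError(f"none of {candidates} could be imported")

# Sketch of the intended behavior: prefer the Torch-enabled bindings,
# fall back to the core-only bindings when Torch is not built in.
# _pybinds = load_first_available(["mori.cpp.libmori_pybinds",
#                                  "mori.cpp.libmori_core_pybinds"])
```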

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 6 comments.

Summary per file:

  • CMakeLists.txt: Adds MORI_ENABLE_TORCH option and compile definition.
  • src/application/CMakeLists.txt: Splits application into mori_application_core and optional mori_application_torch, plus a compatibility target.
  • src/shmem/init.cpp: Wraps ShmemTorchProcessGroupInit in #ifdef MORI_ENABLE_TORCH.
  • src/shmem/CMakeLists.txt: Links against mori_application_core and conditionally mori_application_torch; enables PIC.
  • include/mori/application/bootstrap/bootstrap.hpp: Guards Torch bootstrap include with MORI_ENABLE_TORCH.
  • include/mori/shmem/shmem_host_api.hpp: New host-only shmem API header for framework-agnostic consumers.
  • include/mori/shmem/shmem_api.hpp: Refactors to include host-only APIs and conditionally declare the Torch init API.
  • src/pybind/CMakeLists.txt: Adds core pybind target and makes the Torch pybind target conditional.
  • src/pybind/mori.cpp: Delegates shmem/IO bindings to the new core binding registrars; keeps Torch-specific init.
  • src/pybind/mori_core.cpp / src/pybind/mori_core.hpp / src/pybind/pybind_core.cpp: New framework-agnostic pybind module exposing shmem + IO bindings without Torch.
  • python/mori/cpp/__init__.py: Loads Torch pybinds if available; otherwise loads core-only pybinds.


Comment on lines +26 to +70
#include "src/pybind/mori_core.hpp"

#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "mori/io/io.hpp"
#include "mori/shmem/shmem_host_api.hpp"

namespace py = pybind11;

/* ---------------------------------------------------------------------------------------------- */
/* Shmem wrapper functions */
/* ---------------------------------------------------------------------------------------------- */
namespace {

int64_t ShmemFinalize() { return mori::shmem::ShmemFinalize(); }

int64_t ShmemModuleInit(uint64_t hipModule) {
  return mori::shmem::ShmemModuleInit(reinterpret_cast<void*>(hipModule));
}

int64_t ShmemMyPe() { return mori::shmem::ShmemMyPe(); }

int64_t ShmemNPes() { return mori::shmem::ShmemNPes(); }

// UniqueId-based initialization APIs
py::bytes ShmemGetUniqueId() {
  mori::shmem::mori_shmem_uniqueid_t uid;
  mori::shmem::ShmemGetUniqueId(&uid);
  return py::bytes(reinterpret_cast<const char*>(uid.data()), uid.size());
}

int64_t ShmemInitAttr(unsigned int flags, int32_t rank, int32_t nranks,
                      const py::bytes& uid_bytes) {
  mori::shmem::mori_shmem_init_attr_t attr;
  mori::shmem::mori_shmem_uniqueid_t uid;

  // Convert Python bytes to uniqueid
  Py_ssize_t len = PyBytes_Size(uid_bytes.ptr());
  const char* data = PyBytes_AsString(uid_bytes.ptr());
  if (len != MORI_SHMEM_UNIQUE_ID_BYTES) {
    throw std::runtime_error("Invalid unique ID size");
  }
  std::memcpy(uid.data(), data, MORI_SHMEM_UNIQUE_ID_BYTES);

Copilot AI Feb 24, 2026


This file uses std::memcpy and std::runtime_error but does not include the standard headers that declare them (<cstring> and <stdexcept>). Relying on transitive includes is brittle and can fail on different toolchains/standard library implementations; please add the explicit includes.

@TianDi101
Collaborator Author

Hi @i-chaochen, thanks for the work, some questions:

  • How does JAX call a custom library? For Torch, we use pybind + the torch module; I'm wondering whether JAX shares the same path?
  • Have you ever considered using TVM-FFI? It is a cross-framework / cross-language FFI standard that seems to suit this scenario.

@i-chaochen

i-chaochen commented Feb 24, 2026

Thanks @TianDi101 for raising this as a PR!

Hi @i-chaochen, thanks for the work, some questions:

  • How does JAX call a custom library? For Torch, we use pybind + the torch module; I'm wondering whether JAX shares the same path?

In JAX, it uses the FFI (https://docs.jax.dev/en/latest/ffi.html#foreign-function-interface-ffi) to bridge frontend code and GPU code; you can see our simple usage code here. // jaxpp is just another wrapper atop JAX.

Back to your question, for jax, yes, we're using pybind11 + xla module, and you can see our initial POC branch b2a2bdb#diff-fe7afb5c9c916e521401d3fcfb4277d5071798c3baf83baf11d6071742823584

  • Ever considered using TVM-FFI? It is a cross-framework / cross-language FFI standard which seems suits this scenario.

Yes, it's the FFI I mentioned above, but I'm not familiar with TVM-FFI, and not sure whether it is the same thing and compatible with other frameworks. Essentially, we're using the "PJRT plugin", which is a set of C APIs (within XLA) that define a binary interface.

As a side note, we have rebased and simplified this separation here https://github.com/ROCm/mori/commits/p_cleanup_torch_mpi_deps/ Hope this is easier to accept.

  • The separation between host and device code has not been moved here yet; we still need it.

cc @pemeliya

@pemeliya pemeliya force-pushed the chao/mori_frmk_agnostic branch 3 times, most recently from f5cae4c to 8daa59a Compare February 27, 2026 14:46
@pemeliya pemeliya requested a review from Copilot February 27, 2026 14:57

Copilot AI left a comment


Pull request overview

Copilot reviewed 20 out of 22 changed files in this pull request and generated 5 comments.

Comments suppressed due to low confidence (1)

src/shmem/init.cpp:804

  • ShmemInitAttr(...) returns -1 for MORI_SHMEM_INIT_WITH_MPI_COMM when MPI support is compiled out (the MPI-handling branch is removed by #ifdef ENABLE_MPI), but no error is logged, which makes misconfiguration hard to diagnose. Consider adding an #else branch (or adjusting validation) to emit a clear "MPI support disabled" error when flags == MORI_SHMEM_INIT_WITH_MPI_COMM and ENABLE_MPI is not set.
#ifdef ENABLE_MPI
  // MPI-based initialization
  if (flags == MORI_SHMEM_INIT_WITH_MPI_COMM) {
    if (attr->mpi_comm == nullptr) {
      MORI_SHMEM_ERROR("MPI_Comm is null");


Comment on lines +48 to +59
add_library(mori_core_pybinds SHARED mori_core.cpp pybind_core.cpp)

target_include_directories(
  mori_pybinds PUBLIC ${PYTHON_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS}
                      ${CMAKE_BINARY_DIR}/generated/include)
target_link_directories(mori_pybinds PUBLIC ${TORCH_INSTALL_PREFIX}/lib)
target_link_libraries(
  mori_pybinds
  mori_ops
  mori_io
  ${TORCH_LIBRARIES}
  torch_python
  hip::host
  hip::device)

# For python packages to find dependent libraries
set_target_properties(
  mori_pybinds
  PROPERTIES BUILD_RPATH "$ORIGIN;$ORIGIN/../torch/lib")

target_include_directories(
  mori_core_pybinds PUBLIC
  ${PYTHON_INCLUDE_DIRS} ${PYBIND11_INCLUDE_DIR}
  ${CMAKE_BINARY_DIR}/generated/include)
target_link_libraries(
  mori_core_pybinds
  mori_ops
  mori_io
  hip::host
  hip::device)

Copilot AI Feb 27, 2026


This CMake adds a new mori_core_pybinds shared library, but the existing packaging/loader flow appears to only copy/load libmori_pybinds.so (e.g., setup.py and python/mori/cpp/__init__.py). Ensure libmori_core_pybinds.so is also built/installed/copied and loaded at runtime, or consolidate the bindings to avoid shipping an unusable module.

Comment on lines +185 to +189
m.def("shmem_torch_process_group_init", &ShmemTorchProcessGroupInit);
m.def("shmem_finalize", &ShmemFinalize);
m.def("shmem_mype", &ShmemMyPe);
m.def("shmem_npes", &ShmemNPes);
m.def("shmem_num_qp_per_pe", &ShmemNumQpPerPe);

Copilot AI Feb 27, 2026


RegisterMoriShmem defines several m.def(...) bindings twice (e.g., shmem_torch_process_group_init, shmem_finalize, shmem_mype, shmem_npes). This creates redundant overload entries and can lead to confusing docstrings/overload resolution. Keep a single definition per exported name (with args + docstring) unless you are intentionally providing distinct overloads.

Suggested change
m.def("shmem_torch_process_group_init", &ShmemTorchProcessGroupInit);
m.def("shmem_finalize", &ShmemFinalize);
m.def("shmem_mype", &ShmemMyPe);
m.def("shmem_npes", &ShmemNPes);
m.def("shmem_num_qp_per_pe", &ShmemNumQpPerPe);

@@ -23,6 +23,4 @@

PYBIND11_MODULE(libmori_pybinds, m) {
mori::RegisterMoriOps(m);

Copilot AI Feb 27, 2026


libmori_pybinds no longer registers Shmem/IO symbols, but the current Python loader imports and re-exports from libmori_pybinds (see python/mori/cpp/__init__.py). This is a breaking change unless the Python package is updated to also load libmori_core_pybinds (or libmori_pybinds continues to re-export these registrations).

Suggested change
mori::RegisterMoriOps(m);
mori::RegisterMoriOps(m);
// Import core pybinds and re-export their symbols to maintain backward compatibility.
auto core = pybind11::module_::import("libmori_core_pybinds");
m.attr("__dict__").attr("update")(core.attr("__dict__"));


#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

Copilot AI Feb 27, 2026


mori_core.cpp uses std::memcpy and throws std::runtime_error, but does not include <cstring> / <stdexcept>. Relying on transitive includes makes the file fragile; include the standard headers needed by this translation unit explicitly.

Suggested change
#include <pybind11/stl.h>
#include <pybind11/stl.h>
#include <cstring>
#include <stdexcept>

@jhchouuu
Collaborator

jhchouuu commented Mar 2, 2026

Hi @i-chaochen and @pemeliya , I've been working on a similar direction on the jiahzhou/refactor_framework branch — making the pybind layer fully torch-free. Here's a summary of what I've done that might overlap or help with this PR:
https://github.com/ROCm/mori/tree/jiahzhou/refactor_framework

1. pybind_ops.cpp: torch dependency fully removed

All torch::Tensor parameters replaced with raw int64_t pointers, torch::from_blob replaced with returning raw pointers, TORCH_CHECK moved to Python, getCurrentHIPStream passed from Python as int64_t. The C++ pybind layer now has zero torch includes — no torch/python.h, no ATen/hip/HIPContext.h, no torch_python link dependency.
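
Moving TORCH_CHECK to Python presumably means validating tensors on the Python side before handing raw addresses to the framework-free bindings. A hedged sketch of that pattern (the helper name and the binding call in the comment are illustrative, not code from the branch):

```python
def check_gpu_tensor(t, expected_dtype=None):
    """Python-side replacement for TORCH_CHECK: validate a tensor-like object,
    then return its raw data pointer as an int for the torch-free pybind layer."""
    if not t.is_contiguous():
        raise ValueError("tensor must be contiguous")
    if expected_dtype is not None and t.dtype != expected_dtype:
        raise ValueError(f"expected dtype {expected_dtype}, got {t.dtype}")
    return t.data_ptr()  # int64 address, as the refactored bindings expect

# With torch, usage might look like (hypothetical op name):
#   ptr = check_gpu_tensor(x, torch.float32)
#   stream = torch.cuda.current_stream().cuda_stream  # passed as int64
#   libmori_pybinds.some_op(ptr, stream)
```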

2. Python-side tensor conversion via DLPack + __cuda_array_interface__

A new tensor_utils.py provides GpuTensorView — a lightweight GPU array descriptor implementing both __dlpack__/__dlpack_device__ and __cuda_array_interface__. Any framework can consume it zero-copy:

view = gpu_tensor_view(ptr, shape, dtype)
t = torch.as_tensor(view)           # PyTorch
a = jax.dlpack.from_dlpack(view)    # JAX

https://github.com/ROCm/mori/blob/jiahzhou/refactor_framework/python/mori/tensor_utils.py
Supports float32, bfloat16, float8, float4_e2m1fn_x2, int32.
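
For illustration, the general shape of such a descriptor looks like this. The field names follow the __cuda_array_interface__ v3 protocol; the class itself is a hypothetical sketch, not the actual GpuTensorView from tensor_utils.py:

```python
class GpuArrayView:
    """Minimal zero-copy GPU array descriptor via __cuda_array_interface__.

    typestr uses numpy-style codes, e.g. "<f4" for little-endian float32.
    """
    def __init__(self, ptr, shape, typestr="<f4"):
        self.ptr = ptr
        self.shape = tuple(shape)
        self.typestr = typestr

    @property
    def __cuda_array_interface__(self):
        return {
            "shape": self.shape,
            "typestr": self.typestr,
            "data": (self.ptr, False),  # (device pointer, read_only flag)
            "version": 3,
            "strides": None,            # None means C-contiguous
        }
```

A consumer such as torch.as_tensor or cupy.asarray reads this dict and wraps the device pointer without copying; the real GpuTensorView additionally implements __dlpack__/__dlpack_device__ for DLPack consumers like JAX.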

I haven't been able to test the JAX path end-to-end — I don't have a jax[rocm] environment set up. The DLPack capsule is created correctly (__dlpack_device__ returns the proper ROCm device type), but JAX consumption is untested. Would appreciate guidance on how to set up a jax[rocm] test environment if you have one available.

These changes are on the jiahzhou/refactor_framework branch and still a work in progress. I've also looked at the p_cleanup_torch_mpi_deps branch — there's significant overlap in goals but different approaches (e.g., #ifdef ENABLE_TORCH vs fully removing torch from the pybind API). Happy to discuss whether these conflict or can complement each other.

cc @TianDi101

@i-chaochen i-chaochen force-pushed the chao/mori_frmk_agnostic branch 2 times, most recently from 83faff3 to 297de1f Compare March 3, 2026 01:54
Use conditional compilation in a single mori_application library (no
separate mori_application_core.so / mori_application_torch.so split):
  - torch_bootstrap.cpp added via target_sources when ENABLE_TORCH=ON
  - mpi_bootstrap.cpp added via target_sources when ENABLE_MPI=ON

Pybind builds one of two libraries depending on the enabled framework:
  - libmori_pybinds.so when ENABLE_TORCH=ON (existing PyTorch path)
  - libmori_xla_ffi_ops.so when ENABLE_XLA_FFI=ON (JAX/XLA path)
  No separate libmori_core_pybinds.so target.

No shmem_host_api.hpp split -- mori already separates host and device
code, so the existing shmem_api.hpp is used as-is by both frameworks.

Key changes:
- Add ENABLE_TORCH, ENABLE_MPI cmake options (default ON) with
  compile definitions
- Guard torch_bootstrap and MPI includes behind their respective
  #ifdef guards
- Add POSITION_INDEPENDENT_CODE to mori_shmem for shared lib linking
- Create mori_core.cpp / pybind_core.cpp with shmem + IO bindings
- Refactor mori.cpp to delegate core registration to mori_core.cpp
- Update python/mori/cpp/__init__.py for library auto-detection

Existing PyTorch users are unaffected: ENABLE_TORCH defaults to ON and
all existing APIs/symbols remain available in libmori_pybinds.so.

Vendor XLA FFI headers from rocm-xla/rocm-jaxlib-v0.9.0 (ca6c4f848f)
and implement MoE dispatch/combine FFI handlers so that JAX can invoke
mori kernels via XLA custom calls without any Torch dependency.

New ENABLE_XLA_FFI cmake option (default OFF) gates the entire FFI
build path. When enabled, builds libmori_xla_ffi_ops.so exporting:
  - mori_ep_dispatch, mori_ep_combine        (MoE forward ops)
  - mori_ep_dispatch_recv, mori_ep_combine_recv (async recv ops)
  - mori_ep_reset                            (barrier reset)
  - mori_ffi_create_handle, mori_ffi_destroy_handle (lifecycle)

Key additions:
- 3rdparty/xla_ffi/: vendored c_api.h, api.h, ffi.h (header-only)
- src/ffi/mori_xla_ffi_ops.cpp: FFI handler implementations using
  XLA_FFI_DEFINE_HANDLER_SYMBOL to export C symbols from the .so
- src/ffi/mori_xla_ffi_handle_mgr.hpp: thread-safe singleton managing
  EpDispatchCombineHandle instances by integer ID
- python/mori/jax/: Python package with _ffi_registry (ctypes loader +
  JAX registration), ops.py (EpDispatchCombineOp context manager)
- tests/cpp/ffi/test_ffi_contract.cpp: compile-time + link-time
  contract test for FFI surface stability
- tests/python/test_shim_parity.py: runtime check that torch pybind
  and XLA FFI shims cover the same set of logical MoE operations
- .github/CODEOWNERS: require ops team review on FFI/pybind surface
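
The ctypes-loader half of such an _ffi_registry can be sketched generically as below. This is an assumption about its shape, not the actual code; the symbol names in the commented usage come from the export list above, and the step of wrapping each address in a PyCapsule for JAX registration is omitted:

```python
import ctypes

def load_ffi_symbols(lib_path, symbol_names):
    """Load a shared library and return {name: raw function address} for each
    exported symbol, suitable for later wrapping into PyCapsules for JAX."""
    lib = ctypes.CDLL(lib_path)
    addrs = {}
    for name in symbol_names:
        fn = getattr(lib, name)  # raises AttributeError if the symbol is missing
        addrs[name] = ctypes.cast(fn, ctypes.c_void_p).value
    return addrs

# e.g. load_ffi_symbols("libmori_xla_ffi_ops.so",
#                       ["mori_ep_dispatch", "mori_ep_combine"])
```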

Build fix: add POSITION_INDEPENDENT_CODE to mori_ops (required for
linking the static .a into libmori_xla_ffi_ops.so).

Verified in chao-mori-dev container across 3 configs:
  - ENABLE_TORCH=OFF, ENABLE_XLA_FFI=ON  (XLA-only)
  - ENABLE_TORCH=ON,  ENABLE_XLA_FFI=OFF (Torch-only regression)
  - ENABLE_TORCH=ON,  ENABLE_XLA_FFI=ON  (dual)
@i-chaochen i-chaochen force-pushed the chao/mori_frmk_agnostic branch from 297de1f to aa9eaf8 Compare March 3, 2026 01:57
@i-chaochen

i-chaochen commented Mar 3, 2026

Hi @jhchouuu thanks a lot for your work!

I notice you already adopted the host/device code separation, so you can see we removed that part; my latest rebase doesn't need mori_application_core.so or mori_application_torch.so, we can just gate it behind ENABLE_TORCH=ON.

I would leave the 1st one to @pemeliya

About the 2nd one, on DLPack + __cuda_array_interface__:

Yes, DLPack works on our JAX side with __cuda_array_interface__ (see jax-ml/jax@63b94d5); this is a valid move for us as well. Thanks!

In addition, we have added the JAX/XLA FFI header files in our 2nd commit to support JAX-XLA.

Regarding the overlap with your branch: as long as it can land on ToT of mori and works on the JAX-XLA side, we don't mind how you or your team would like to proceed :)

If your changes can merge more quickly, we can discard the overlapping part and rebase our work. What matters most to us is being able to use ToT of mori and the XLA-FFI header files (2nd commit aa9eaf8).

@i-chaochen i-chaochen changed the title Feature: make MORI framework agnostic Feature: make MORI framework agnostic and compatible with JAX-XLA Mar 3, 2026
@pemeliya

pemeliya commented Mar 3, 2026

Hi @jhchouuu, I agree with Chao: your approach of removing torch deps completely from the C++ part is cleaner (vs just adding ifdefs). We did it this way in the first place because we did not want to introduce breaking changes on the Python side. So, we can wait until your changes are merged.

As for __cuda_array_interface__, this really does not matter for us, since XLA/JAX custom calls work in a completely different way. Later on, we will probably have to extend/modify the underlying dispatch/combine calls to work with symmetric-allocated output buffers. Note that in XLA there is no analogue of torch::from_blob, since XLA allocates input and output buffers upfront.

@jhchouuu
Collaborator

jhchouuu commented Mar 3, 2026

Hi @i-chaochen and @pemeliya, thanks for your suggestions!

I am now focused on this branch. However, since the refactoring scope is relatively large, we will merge it only after extensive testing. This will probably take a few days; we expect to complete it by Friday (2026.3.6). In the meantime, we will also spend some time learning JAX to ensure that our changes won't cause major disruptions for you. By the way, the main guiding principles of our refactoring are: 1) removing Torch dependencies completely from the C++ part (we will no longer use methods like torch::from_blob either), and 2) separating host-side and device-side code as much as possible, adding a JIT framework where device-side code is JIT-compiled based on the NIC and GPU arch.

jhchouuu added a commit that referenced this pull request Mar 6, 2026
Based on PR #173 by Chao Chen <cchen104@amd.com>, adapted for the
refactored architecture (raw pointer args + Python-side kernel launch).

5 participants