Feature: make MORI framework agnostic and compatible with JAX-XLA #173
Conversation
Pull request overview
This PR introduces a Torch-optional (“framework agnostic”) build by splitting Torch-dependent components (bootstrap + pybind entrypoints) from the core shmem/IO functionality, and gating Torch-only codepaths behind MORI_ENABLE_TORCH.
Changes:
- Add a host-only shmem header (shmem_host_api.hpp) and wrap Torch-specific shmem init behind MORI_ENABLE_TORCH.
- Split application and pybind targets into core vs Torch-specific libraries, updating CMake link dependencies accordingly.
- Update Python import loader to prefer Torch-enabled pybinds when present, otherwise fall back to core-only pybinds.
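The loader change can be sketched roughly as follows. The module names come from the PR description; the exact fallback logic in `python/mori/cpp/__init__.py` may differ, so treat this as an illustration of the idea, not the shipped code:

```python
import importlib


def load_pybinds(preferred: str = "libmori_pybinds",
                 fallback: str = "libmori_core_pybinds"):
    """Prefer the Torch-enabled bindings; fall back to the core-only module.

    Sketch only: the real __init__.py may also adjust sys.path or handle
    additional library variants.
    """
    try:
        return importlib.import_module(preferred)
    except ImportError:
        return importlib.import_module(fallback)
```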
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| CMakeLists.txt | Adds MORI_ENABLE_TORCH option and compile definition. |
| src/application/CMakeLists.txt | Splits application into mori_application_core and optional mori_application_torch, plus a compatibility target. |
| src/shmem/init.cpp | Wraps ShmemTorchProcessGroupInit in #ifdef MORI_ENABLE_TORCH. |
| src/shmem/CMakeLists.txt | Links against mori_application_core and conditionally mori_application_torch; enables PIC. |
| include/mori/application/bootstrap/bootstrap.hpp | Guards Torch bootstrap include with MORI_ENABLE_TORCH. |
| include/mori/shmem/shmem_host_api.hpp | New host-only shmem API header for framework-agnostic consumers. |
| include/mori/shmem/shmem_api.hpp | Refactors to include host-only APIs and conditionally declare Torch init API. |
| src/pybind/CMakeLists.txt | Adds core pybind target and makes Torch pybind target conditional. |
| src/pybind/mori.cpp | Delegates shmem/IO bindings to the new core binding registrars; keeps Torch-specific init. |
| src/pybind/mori_core.cpp / src/pybind/mori_core.hpp / src/pybind/pybind_core.cpp | New framework-agnostic pybind module exposing shmem + IO bindings without Torch. |
| python/mori/cpp/__init__.py | Loads Torch pybinds if available; otherwise loads core-only pybinds. |
Excerpt from src/pybind/mori_core.cpp:

```cpp
#include "src/pybind/mori_core.hpp"

#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "mori/io/io.hpp"
#include "mori/shmem/shmem_host_api.hpp"

namespace py = pybind11;

/* ---------------------------------------------------------------------------------------------- */
/* Shmem wrapper functions                                                                         */
/* ---------------------------------------------------------------------------------------------- */
namespace {

int64_t ShmemFinalize() { return mori::shmem::ShmemFinalize(); }

int64_t ShmemModuleInit(uint64_t hipModule) {
  return mori::shmem::ShmemModuleInit(reinterpret_cast<void*>(hipModule));
}

int64_t ShmemMyPe() { return mori::shmem::ShmemMyPe(); }

int64_t ShmemNPes() { return mori::shmem::ShmemNPes(); }

// UniqueId-based initialization APIs
py::bytes ShmemGetUniqueId() {
  mori::shmem::mori_shmem_uniqueid_t uid;
  mori::shmem::ShmemGetUniqueId(&uid);
  return py::bytes(reinterpret_cast<const char*>(uid.data()), uid.size());
}

int64_t ShmemInitAttr(unsigned int flags, int32_t rank, int32_t nranks,
                      const py::bytes& uid_bytes) {
  mori::shmem::mori_shmem_init_attr_t attr;
  mori::shmem::mori_shmem_uniqueid_t uid;

  // Convert Python bytes to uniqueid
  Py_ssize_t len = PyBytes_Size(uid_bytes.ptr());
  const char* data = PyBytes_AsString(uid_bytes.ptr());
  if (len != MORI_SHMEM_UNIQUE_ID_BYTES) {
    throw std::runtime_error("Invalid unique ID size");
  }
  std::memcpy(uid.data(), data, MORI_SHMEM_UNIQUE_ID_BYTES);
```
This file uses std::memcpy and std::runtime_error but does not include the standard headers that declare them (<cstring> and <stdexcept>). Relying on transitive includes is brittle and can fail on different toolchains/standard library implementations; please add the explicit includes.
Hi @i-chaochen , thanks for the work, some questions:
Thanks @TianDi101 for bringing this up as a PR!

In JAX, the FFI (https://docs.jax.dev/en/latest/ffi.html#foreign-function-interface-ffi) is used to connect front-end code and GPU code; you can see our simple usage code here (jaxpp is just another wrapper on top of JAX). Back to your question: for JAX, yes, we're using pybind11 + the XLA module, and you can see our initial POC branch b2a2bdb#diff-fe7afb5c9c916e521401d3fcfb4277d5071798c3baf83baf11d6071742823584

Yes, it's FFI as I mentioned above, but I'm not familiar with TVM-FFI, so I'm not sure whether it's the same thing and compatible with other frameworks. Essentially, it uses the "PJRT plugin", a set of C APIs (within XLA) that define a binary interface. As a side note, we have rebased and simplified this separation here: https://github.com/ROCm/mori/commits/p_cleanup_torch_mpi_deps/ Hope this is easier to accept.
cc @pemeliya
Force-pushed from f5cae4c to 8daa59a.
Pull request overview
Copilot reviewed 20 out of 22 changed files in this pull request and generated 5 comments.
Comments suppressed due to low confidence (1)
src/shmem/init.cpp:804
`ShmemInitAttr(...)` returns `-1` for `MORI_SHMEM_INIT_WITH_MPI_COMM` when MPI support is compiled out (the MPI-handling branch is removed by `#ifdef ENABLE_MPI`), but no error is logged, which makes misconfiguration hard to diagnose. Consider adding an `#else` branch (or adjusting validation) to emit a clear "MPI support disabled" error when `flags == MORI_SHMEM_INIT_WITH_MPI_COMM` and `ENABLE_MPI` is not set.
```cpp
#ifdef ENABLE_MPI
  // MPI-based initialization
  if (flags == MORI_SHMEM_INIT_WITH_MPI_COMM) {
    if (attr->mpi_comm == nullptr) {
      MORI_SHMEM_ERROR("MPI_Comm is null");
```
Excerpt from src/pybind/CMakeLists.txt (lines elided by the review view are kept as in the quoted diff):

```cmake
add_library(mori_core_pybinds SHARED mori_core.cpp pybind_core.cpp)

target_include_directories(
  mori_pybinds PUBLIC ${PYTHON_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS}
  ${CMAKE_BINARY_DIR}/generated/include)
target_link_directories(mori_pybinds PUBLIC ${TORCH_INSTALL_PREFIX}/lib)
target_link_libraries(
  mori_pybinds
  mori_ops
  mori_io
  ${TORCH_LIBRARIES}
  torch_python
  hip::host
  hip::device)

# For python packages to find dependent libraries
set_target_properties(
  mori_pybinds
  PROPERTIES BUILD_RPATH "$ORIGIN;$ORIGIN/../torch/lib"

  mori_core_pybinds PUBLIC
  ${PYTHON_INCLUDE_DIRS} ${PYBIND11_INCLUDE_DIR}
  ${CMAKE_BINARY_DIR}/generated/include)
target_link_libraries(
  mori_core_pybinds
  mori_ops
  mori_io
  hip::host
  hip::device)
```
This CMake adds a new mori_core_pybinds shared library, but the existing packaging/loader flow appears to only copy/load libmori_pybinds.so (e.g., setup.py and python/mori/cpp/__init__.py). Ensure libmori_core_pybinds.so is also built/installed/copied and loaded at runtime, or consolidate the bindings to avoid shipping an unusable module.
```cpp
m.def("shmem_torch_process_group_init", &ShmemTorchProcessGroupInit);
m.def("shmem_finalize", &ShmemFinalize);
m.def("shmem_mype", &ShmemMyPe);
m.def("shmem_npes", &ShmemNPes);
m.def("shmem_num_qp_per_pe", &ShmemNumQpPerPe);
```
RegisterMoriShmem defines several m.def(...) bindings twice (e.g., shmem_torch_process_group_init, shmem_finalize, shmem_mype, shmem_npes). This creates redundant overload entries and can lead to confusing docstrings/overload resolution. Keep a single definition per exported name (with args + docstring) unless you are intentionally providing distinct overloads.
```cpp
@@ -23,6 +23,4 @@

PYBIND11_MODULE(libmori_pybinds, m) {
  mori::RegisterMoriOps(m);
```
libmori_pybinds no longer registers Shmem/IO symbols, but the current Python loader imports and re-exports from libmori_pybinds (see python/mori/cpp/__init__.py). This is a breaking change unless the Python package is updated to also load libmori_core_pybinds (or libmori_pybinds continues to re-export these registrations).
Suggested change:

```cpp
mori::RegisterMoriOps(m);
// Import core pybinds and re-export their symbols to maintain backward compatibility.
auto core = pybind11::module_::import("libmori_core_pybinds");
m.attr("__dict__").attr("update")(core.attr("__dict__"));
```
|
|
```cpp
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
```
mori_core.cpp uses std::memcpy and throws std::runtime_error, but does not include <cstring> / <stdexcept>. Relying on transitive includes makes the file fragile; include the standard headers needed by this translation unit explicitly.
Suggested change:

```cpp
#include <pybind11/stl.h>

#include <cstring>
#include <stdexcept>
```
Hi @i-chaochen and @pemeliya, I've been working in a similar direction on the …

1. pybind_ops.cpp: torch dependency fully removed. All …
2. Python-side tensor conversion via DLPack + …
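The DLPack idea mentioned here can be illustrated framework-agnostically with NumPy, which speaks the same `__dlpack__` protocol that JAX and Torch use. This is a sketch of the zero-copy handoff, not mori's actual conversion code (requires NumPy >= 1.22):

```python
import numpy as np

# Producer array, standing in for a framework tensor that exposes __dlpack__.
src = np.arange(4, dtype=np.float32)

# Consumer imports it through the DLPack protocol: no copy, same buffer.
dst = np.from_dlpack(src)

# Writes through one view are visible through the other.
src[0] = 42.0
assert dst[0] == 42.0  # dst aliases src's memory
```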
Force-pushed from 83faff3 to 297de1f.
Use conditional compilation in a single mori_application library (no separate mori_application_core.so / mori_application_torch.so split):

- torch_bootstrap.cpp added via target_sources when ENABLE_TORCH=ON
- mpi_bootstrap.cpp added via target_sources when ENABLE_MPI=ON

Pybind builds one of two libraries depending on the enabled framework:

- libmori_pybinds.so when ENABLE_TORCH=ON (existing PyTorch path)
- libmori_xla_ffi_ops.so when ENABLE_XLA_FFI=ON (JAX/XLA path)

No separate libmori_core_pybinds.so target. No shmem_host_api.hpp split: mori already separates host and device code, so the existing shmem_api.hpp is used as-is by both frameworks.

Key changes:

- Add ENABLE_TORCH, ENABLE_MPI cmake options (default ON) with compile definitions
- Guard torch_bootstrap and MPI includes behind their respective #ifdef guards
- Add POSITION_INDEPENDENT_CODE to mori_shmem for shared lib linking
- Create mori_core.cpp / pybind_core.cpp with shmem + IO bindings
- Refactor mori.cpp to delegate core registration to mori_core.cpp
- Update python/mori/cpp/__init__.py for library auto-detection

Existing PyTorch users are unaffected: ENABLE_TORCH defaults to ON and all existing APIs/symbols remain available in libmori_pybinds.so.
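Under the assumptions of this commit message (the option names are real; the target layout and file paths below are guesses for illustration), the single-library conditional layout might look roughly like:

```cmake
option(ENABLE_TORCH "Enable Torch bootstrap/pybind support" ON)
option(ENABLE_MPI "Enable MPI bootstrap support" ON)

# Core sources always built; source path assumed for this sketch.
add_library(mori_application STATIC application.cpp)
set_target_properties(mori_application PROPERTIES POSITION_INDEPENDENT_CODE ON)

if(ENABLE_TORCH)
  target_sources(mori_application PRIVATE bootstrap/torch_bootstrap.cpp)
  target_compile_definitions(mori_application PUBLIC ENABLE_TORCH)
endif()

if(ENABLE_MPI)
  target_sources(mori_application PRIVATE bootstrap/mpi_bootstrap.cpp)
  target_compile_definitions(mori_application PUBLIC ENABLE_MPI)
endif()
```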
Vendor XLA FFI headers from rocm-xla/rocm-jaxlib-v0.9.0 (ca6c4f848f) and implement MoE dispatch/combine FFI handlers so that JAX can invoke mori kernels via XLA custom calls without any Torch dependency.

New ENABLE_XLA_FFI cmake option (default OFF) gates the entire FFI build path. When enabled, builds libmori_xla_ffi_ops.so exporting:

- mori_ep_dispatch, mori_ep_combine (MoE forward ops)
- mori_ep_dispatch_recv, mori_ep_combine_recv (async recv ops)
- mori_ep_reset (barrier reset)
- mori_ffi_create_handle, mori_ffi_destroy_handle (lifecycle)

Key additions:

- 3rdparty/xla_ffi/: vendored c_api.h, api.h, ffi.h (header-only)
- src/ffi/mori_xla_ffi_ops.cpp: FFI handler implementations using XLA_FFI_DEFINE_HANDLER_SYMBOL to export C symbols from the .so
- src/ffi/mori_xla_ffi_handle_mgr.hpp: thread-safe singleton managing EpDispatchCombineHandle instances by integer ID
- python/mori/jax/: Python package with _ffi_registry (ctypes loader + JAX registration), ops.py (EpDispatchCombineOp context manager)
- tests/cpp/ffi/test_ffi_contract.cpp: compile-time + link-time contract test for FFI surface stability
- tests/python/test_shim_parity.py: runtime check that torch pybind and XLA FFI shims cover the same set of logical MoE operations
- .github/CODEOWNERS: require ops team review on FFI/pybind surface

Build fix: add POSITION_INDEPENDENT_CODE to mori_ops (required for linking the static .a into libmori_xla_ffi_ops.so).

Verified in chao-mori-dev container across 3 configs:

- ENABLE_TORCH=OFF, ENABLE_XLA_FFI=ON (XLA-only)
- ENABLE_TORCH=ON, ENABLE_XLA_FFI=OFF (Torch-only regression)
- ENABLE_TORCH=ON, ENABLE_XLA_FFI=ON (dual)
Force-pushed from 297de1f to aa9eaf8.
Hi @jhchouuu, thanks a lot for your work! I notice you already adopted the changes about host/device code; as you can see, we removed that part, and my latest rebase doesn't need it.

I would leave the 1st one to @pemeliya. About the 2nd one, on DLPack: yes, DLPack works on our JAX side with … In addition, we have added the JAX/XLA FFI header files to enable JAX XLA support in our 2nd commit.

Regarding the overlap with your branch: as long as it can land on ToT of mori and works on the JAX-XLA side, we don't mind how you or your team would like to proceed :) If your changes can be merged more quickly, we can discard the overlapped parts and rebase our work. What matters most to us is being able to use ToT of mori and the XLA FFI header files (2nd commit aa9eaf8).
Hi @jhchouuu, I agree with Chao: your approach of removing torch deps completely from the C++ part is cleaner (vs. just adding ifdefs). We did it this way in the first place because we did not want to introduce breaking changes on the python side. So, we can wait until your changes are merged. As for cuda_array_interface, this really does not matter for us, since XLA/JAX custom calls work in a completely different way. Later on, we will probably have to extend/modify the underlying dispatch/combine calls to be able to work with symmetric-allocated output buffers. Note that in XLA there is no analogue of torch::from_blob, since XLA allocates input and output buffers upfront.
Hi @i-chaochen and @pemeliya, thanks for your suggestions! I am now focused on this branch; however, since the refactoring scope is relatively large, we will merge it only after extensive testing. This will probably take a few days, and we expect to complete it by Friday (2026.3.6). In the meantime, we will also spend some time learning JAX to ensure that our changes won't cause major disruptions for you. By the way, the main guiding principles of our refactoring are: 1) removing Torch dependencies completely from the C++ part (I think we will no longer use methods like torch::from_blob either), and 2) separating host-side and device-side code as much as possible, adding a JIT framework where device-side code is JIT-compiled based on different NIC and GPU archs.
Based on PR #173 by Chao Chen <cchen104@amd.com>, adapted for the refactored architecture (raw pointer args + Python-side kernel launch).