Skip to content

Added planar types to speed up complex half precision GEMMs#1142

Open
cliffburdick wants to merge 5 commits intomainfrom
planar_tensor
Open

Added planar types to speed up complex half precision GEMMs#1142
cliffburdick wants to merge 5 commits intomainfrom
planar_tensor

Conversation

@cliffburdick
Copy link
Collaborator

No description provided.

@copy-pr-bot
Copy link

copy-pr-bot bot commented Mar 19, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 19, 2026

Greptile Summary

This PR introduces matxFp16ComplexPlanar and matxBf16ComplexPlanar tag types and the infrastructure to use them end-to-end, enabling cuBLAS GEMM to consume pre-converted planar buffers directly rather than allocating temporaries and running interleaved↔planar conversion kernels on every call. The changes touch the tensor core (new PlanarComplexProxy proxy-reference, LoadPlanarComplex/StorePlanarComplex helpers, contiguity validation), the set operator (EPT gating, planar write path), and the cuBLAS GEMM executor (skip-conversion fast path, ldc fix for planar C). Previous review threads around the TotalSize() non-contiguous offset bug, the c_adj pointer mismatch, and the EPT regression are all addressed in this revision.

Key observations:

  • tensor_impl_t const operator() inconsistency: For non-planar types, the const overload returns T& (writable through the pointer), maintaining write-through semantics even on const tensors. For planar types, the new const overload returns T by value via LoadPlanarComplex, which silently discards any assignment. SetOp avoids this today via mutable out_, but any other code holding a const tensor_t<PlanarType> and writing through operator() will produce a silent no-op. The fix is to return PlanarComplexProxy{const_cast<self_type*>(this), offset} from the const overload as well.
  • PlanarComplexProxy::real() / imag() double-load: Each accessor independently calls LoadPlanarComplex, performing two full scalar-pair reads when both components are needed; reading directly from the scalar planes avoids the extra load.
  • planar(ComplexInterleavedOp) overload discoverability: The planar(interleaved(x)) identity shortcut is defined in interleaved.h; users who include only planar.h will silently get a doubly-wrapped operator instead.
  • Redundant assertion in ValidatePlanarLayoutOnCreate_: Unit innermost stride is implied by IsContiguous(), so the first MATX_ASSERT_STR is redundant.
  • The params.ldc = c.Size(RANK-1) override for all complex-half types is correct for contiguous tensors (which planar types enforce) and the temp-buffer non-planar path.
  • Cache-key hash and equality operators are properly updated for the new a_planar/b_planar/c_planar flags to prevent stale plan reuse.

Confidence Score: 3/5

  • Core GEMM fast path and EPT fixes are sound, but the const operator() inconsistency for planar types is a latent silent-correctness hazard that should be addressed before merging.
  • The previous three critical issues are resolved. The remaining concerns are: (1) the const operator() returning a value instead of a writable proxy — currently masked by mutable out_ in SetOp, but a latent trap for any caller using const tensor_t with planar types; (2) minor proxy double-load and discoverability issues. The GEMM path itself is logically correct.
  • include/matx/core/tensor_impl.h — the const operator() path for planar types needs to return a writable proxy rather than a value to preserve write-through semantics consistent with non-planar types.

Important Files Changed

Filename Overview
include/matx/core/tensor_impl.h Adds PlanarComplexProxy for proxy-reference write semantics and LoadPlanarComplex/StorePlanarComplex helpers. The const operator() returns T by value instead of a writable proxy, creating an inconsistency with non-planar types and silently dropping writes in non-SetOp const contexts.
include/matx/operators/set.h Fixes the previously reported EPT regression by gating {ONE, ONE} on is_planar_complex_v; adds planar-output branch to _internal_mapply and operator() that forces DefaultCapabilities; works correctly thanks to mutable out_.
include/matx/transforms/matmul/matmul_cuda.h Skips interleaved↔planar conversions when inputs/output are already planar; sets params.ldc = c.Size(RANK-1) for all complex-half C tensors; updates param struct and hash/eq for cache keying; adds planar-type support to CompatibleGemmCUDATypes. Overall logic is sound given that planar tensors enforce contiguity.
include/matx/core/tensor.h Adds ValidatePlanarLayoutOnCreate_() to all constructors and Reset() variants to enforce contiguity for planar tensors, addressing the earlier TotalSize()-offset concern; contains a minor redundant assertion (unit innermost stride is implied by IsContiguous()).
include/matx/core/half_complex.h Introduces matxFp16ComplexPlanar and matxBf16ComplexPlanar tag types that inherit from their interleaved counterparts; provides copy-constructors and assignment from the base type for interoperability.
include/matx/core/type_utils_both.h Adds is_planar_complex_v<T> trait and propagates planar types through all existing type-trait checks (is_complex, is_complex_half, is_fp16_type, is_bf16_type, is_matx_type); consistent and complete.
include/matx/operators/interleaved.h Adds InnerOp() to ComplexInterleavedOp and defines both interleaved(planar(x)) and planar(interleaved(x)) identity shortcuts; includes planar.h to access ComplexPlanarOp. The planar() overload being placed here rather than in planar.h is a discoverability concern.

Sequence Diagram

sequenceDiagram
    participant User
    participant SetOp
    participant PlanarTensor as tensor_t PlanarType
    participant PlanarProxy as PlanarComplexProxy
    participant MatMulCUDA
    participant cuBLAS

    Note over User,cuBLAS: Element-wise assign to planar output
    User->>SetOp: run(exec)
    SetOp->>SetOp: _internal_mapply(i,j) const with mutable out_
    SetOp->>PlanarTensor: out_.operator() non-const via mutable
    PlanarTensor-->>PlanarProxy: PlanarComplexProxy{this, offset}
    SetOp->>PlanarProxy: proxy = get_value(op_, i, j)
    PlanarProxy->>PlanarTensor: StorePlanarComplex(offset, val)
    PlanarTensor->>PlanarTensor: base[offset]=real, base[offset+N]=imag

    Note over User,cuBLAS: GEMM with pre-planar inputs
    User->>MatMulCUDA: Execute(a_planar, b_planar, c_planar, stream)
    MatMulCUDA->>MatMulCUDA: a_is_planar=true, skip conversion
    MatMulCUDA->>MatMulCUDA: b_is_planar=true, skip conversion
    MatMulCUDA->>MatMulCUDA: c_is_planar=true, c_adj.Reset(c.Data())
    MatMulCUDA->>MatMulCUDA: params.ldc = c.Size(RANK-1)
    MatMulCUDA->>cuBLAS: cublasGemmEx CUDA_C_16F planar pointers
    cuBLAS-->>MatMulCUDA: writes planar result to c.Data() directly
    MatMulCUDA->>MatMulCUDA: c_is_planar=true, skip interleaved conversion
    MatMulCUDA-->>User: done

    Note over User,cuBLAS: Read back planar to interleaved
    User->>SetOp: run(exec) for c_interleaved = c_planar
    SetOp->>PlanarTensor: get_value(c_planar, i, j) const
    PlanarTensor->>PlanarTensor: LoadPlanarComplex(offset)
    PlanarTensor-->>SetOp: T with real=base[off], imag=base[off+N]
    SetOp->>SetOp: store to c_interleaved(i,j)
Loading

Comments Outside Diff (4)

  1. include/matx/core/tensor_impl.h, line 1344-1354 (link)

    Const operator() returns value, not proxy — silent write loss outside SetOp

    The const overload of operator() for planar types returns T by value (via LoadPlanarComplex), while for non-planar types it returns T& (a writable reference through data_.ldata_). This asymmetry means that any caller holding a const tensor_t<PlanarType> and attempting a write will silently discard the value.

    SetOp sidesteps this today via mutable out_, but any code that passes a planar tensor as a const & and then writes to it through operator() (a common pattern in MATX's operator infrastructure) will silently produce a no-op. For example, a kernel functor that captures a planar output tensor by const-copy and tries to assign:

    // const tensor_impl_t<matxFp16ComplexPlanar, ...>::operator()<CapType>(i, j)
    return LoadPlanarComplex(offset); // ← returns T, not T& or proxy
    // caller does: value = rhs; // ← no-op, value is a temporary

    Non-planar tensors work in const context because data_.ldata_[offset] returns T& (writes through the pointer are not blocked by const on the descriptor). The fix for planar types is to return a PlanarComplexProxy from the const overload as well, using const_cast<self_type*>(this) — since StorePlanarComplex only writes through the pointer (data_.ldata_), not through the object itself, this is semantically safe and matches the non-planar pattern.

    // Suggested fix: return a writable proxy even from the const overload
    if constexpr (is_planar_complex_v<T>) {
      return PlanarComplexProxy{const_cast<self_type*>(this), offset};
    }
  2. include/matx/core/tensor_impl.h, line 1758-1784 (link)

    PlanarComplexProxy::real() / imag() each trigger a full double-plane load

    Both real() and imag() accessors on the proxy call LoadPlanarComplex(offset) independently. LoadPlanarComplex reads both the real scalar (base[offset]) and imaginary scalar (base[offset + total]) to construct a full T. If a caller reads both components (e.g. proxy.real() then proxy.imag()), two full pair-of-reads occur instead of one.

    Consider reading each scalar directly from the planes:

    __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto real() const
    {
      using Scalar = typename T::value_type;
      return reinterpret_cast<const Scalar *>(self->data_.ldata_)[offset];
    }
    
    __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto imag() const
    {
      using Scalar = typename T::value_type;
      return reinterpret_cast<const Scalar *>(self->data_.ldata_)[offset + self->TotalSize()];
    }

    This accesses only the needed scalar and avoids loading the other plane.

  3. include/matx/operators/interleaved.h, line 247-263 (link)

    planar() overload placed in interleaved.h — discoverable only if both headers are included

    The planar(ComplexInterleavedOp) shortcut is defined here in interleaved.h, but users who include only planar.h (without interleaved.h) will not see this overload and will accidentally construct a double-wrapped operator (ComplexPlanarOp<ComplexInterleavedOp<T>>) instead of getting the inner op back.

    Both shortcut overloads (interleaved(planar(x)) and planar(interleaved(x))) should live in a single location — either both in planar.h (with interleaved.h included there), or both in a new combined header — rather than split across two files.

  4. include/matx/core/tensor.h, line 1537-1548 (link)

    Redundant assertion in ValidatePlanarLayoutOnCreate_

    The first assertion (Stride(RANK - 1) == 1) is already subsumed by the second (IsContiguous()). A contiguous tensor by definition has unit innermost stride, so the first check never fails independently of the second. Consider removing it:

Last reviewed commit: "Compilation error"

@cliffburdick
Copy link
Collaborator Author

/build

1 similar comment
@cliffburdick
Copy link
Collaborator Author

/build

@cliffburdick
Copy link
Collaborator Author

/build

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant