Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions csrc/selective_scan/selective_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
void* delta_bias_ptr,
void* x_ptr,
bool has_z,
bool delta_softplus) {
bool delta_softplus,
const at::Tensor initial_state) {

// Reset the parameters
memset(&params, 0, sizeof(params));
Expand Down Expand Up @@ -138,6 +139,18 @@ void set_ssm_params_fwd(SSMParamsBase &params,
}
params.out_batch_stride = out.stride(0);
params.out_d_stride = out.stride(1);

// Set initial state if provided
params.initial_state_ptr = initial_state.defined() ? initial_state.data_ptr() : nullptr;
if (initial_state.defined()) {
params.initial_state_batch_stride = initial_state.stride(0);
params.initial_state_d_stride = initial_state.stride(1);
params.initial_state_dstate_stride = initial_state.stride(2);
} else {
params.initial_state_batch_stride = 0;
params.initial_state_d_stride = 0;
params.initial_state_dstate_stride = 0;
}
}

void set_ssm_params_bwd(SSMParamsBwd &params,
Expand Down Expand Up @@ -181,7 +194,7 @@ void set_ssm_params_bwd(SSMParamsBwd &params,
// If not recompute_out_z, pass dout instead of out_z.
// This won't be used by the bwd kernel
recompute_out_z ? out_z : dout,
D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus, at::Tensor());
if (!recompute_out_z) { params.out_z_ptr = nullptr; }

// Set the pointers and strides.
Expand Down Expand Up @@ -229,7 +242,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
const c10::optional<at::Tensor> &D_,
const c10::optional<at::Tensor> &z_,
const c10::optional<at::Tensor> &delta_bias_,
bool delta_softplus) {
bool delta_softplus,
const c10::optional<at::Tensor> &initial_state_ = c10::nullopt) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
Expand Down Expand Up @@ -293,6 +307,15 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
CHECK_SHAPE(delta_bias, dim);
}

if (initial_state_.has_value()) {
auto initial_state = initial_state_.value();
TORCH_CHECK(initial_state.scalar_type() == weight_type);
TORCH_CHECK(initial_state.is_cuda());
TORCH_CHECK(initial_state.dim() == 3);
CHECK_SHAPE(initial_state, batch_size, dim, dstate);
TORCH_CHECK(initial_state.stride(-1) == 1 || initial_state.size(-1) == 1);
}

at::Tensor z, out_z;
const bool has_z = z_.has_value();
if (has_z) {
Expand All @@ -319,7 +342,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
x.data_ptr(),
has_z,
delta_softplus);
delta_softplus,
initial_state_.has_value() ? initial_state_.value() : at::Tensor());

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
Expand Down
4 changes: 4 additions & 0 deletions csrc/selective_scan/selective_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ struct SSMParamsBase {
void *__restrict__ x_ptr;
void *__restrict__ z_ptr;
void *__restrict__ out_z_ptr;
void *__restrict__ initial_state_ptr; // Optional initial state (batch, dim, dstate)
index_t initial_state_batch_stride;
index_t initial_state_d_stride;
index_t initial_state_dstate_stride;
};

struct SSMParamsBwd: public SSMParamsBase {
Expand Down
40 changes: 36 additions & 4 deletions csrc/selective_scan/selective_scan_fwd_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
bool kIsVariableB_, bool kIsVariableC_,
bool kHasZ_, typename input_t_, typename weight_t_>
bool kHasZ_, bool kHasInitialState_, typename input_t_, typename weight_t_>
struct Selective_Scan_fwd_kernel_traits {
static_assert(kNItems_ % 4 == 0);
using input_t = input_t_;
Expand All @@ -43,6 +43,7 @@ struct Selective_Scan_fwd_kernel_traits {
static constexpr bool kIsVariableB = kIsVariableB_;
static constexpr bool kIsVariableC = kIsVariableC_;
static constexpr bool kHasZ = kHasZ_;
static constexpr bool kHasInitialState = kHasInitialState_;

static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;

Expand Down Expand Up @@ -76,6 +77,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
constexpr bool kIsVariableB = Ktraits::kIsVariableB;
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
constexpr bool kHasZ = Ktraits::kHasZ;
constexpr bool kHasInitialState = Ktraits::kHasInitialState;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNItems = Ktraits::kNItems;
constexpr int kNRows = Ktraits::kNRows;
Expand Down Expand Up @@ -218,8 +220,21 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
#pragma unroll
for (int i = 0; i < kNItems; ++i) {
if constexpr (!kIsComplex) {
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
weight_t delta_a = exp2f(delta_vals[r][i] * A_val[r]);
weight_t delta_b_u = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];

if constexpr (kHasInitialState) {
if (chunk == 0 && i == 0 && threadIdx.x == 0) {
const weight_t *initial_state = reinterpret_cast<const weight_t *>(params.initial_state_ptr)
+ batch_id * params.initial_state_batch_stride
+ dim_id * params.initial_state_d_stride;
weight_t h0_val = initial_state[state_idx * params.initial_state_dstate_stride];
// Modify: deltaB[0]*u[0] -> deltaA[0]*h0 + deltaB[0]*u[0]
delta_b_u = delta_a * h0_val + delta_b_u;
}
}

thread_data[i] = make_float2(delta_a, delta_b_u);
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
thread_data[i] = make_float2(1.f, 0.f);
Expand All @@ -229,6 +244,21 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// Pytorch's implementation of complex exp (which calls thrust) is very slow
complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];

// Incorporate initial state for chunk 0, first timestep (complex case)
if constexpr (kHasInitialState) {
if (chunk == 0 && i == 0 && threadIdx.x == 0) {
// For complex, initial_state is stored as complex_t (interleaved real/imag)
const complex_t *initial_state_complex = reinterpret_cast<const complex_t *>(params.initial_state_ptr)
+ batch_id * (params.initial_state_batch_stride / 2)
+ dim_id * (params.initial_state_d_stride / 2);
complex_t h0_val = initial_state_complex[state_idx * (params.initial_state_dstate_stride / 2)];
complex_t h0_contrib = delta_a_exp * h0_val;
// B_delta_u_val is already complex_t, add h0_contrib
B_delta_u_val = h0_contrib + B_delta_u_val;
}
}

thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
Expand Down Expand Up @@ -316,7 +346,8 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
BOOL_SWITCH(params.initial_state_ptr != nullptr, kHasInitialState, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kHasInitialState, input_t, weight_t>;

constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
Expand All @@ -341,6 +372,7 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {

kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
Expand Down
27 changes: 20 additions & 7 deletions mamba_ssm/modules/mamba_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,21 +160,32 @@ def forward(self, hidden_states, inference_params=None):
)
else:
x, z = xz.chunk(2, dim=1)
# Compute short convolution
# Compute short convolution, state continuity logic is inference-only
if conv_state is not None:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
k = self.d_conv - 1
conv_inputs = conv_state[:, :, -k:]
x_input = torch.cat([conv_inputs, x], dim=2)
conv_state.copy_(F.pad(x_input, (self.d_conv - x_input.shape[-1], 0))[:, :, -self.d_conv:]) # Update state (B D W)
else:
x_input = x
if causal_conv1d_fn is None:
x = self.act(self.conv1d(x)[..., :seqlen])
x_conv = self.conv1d(x_input)
if conv_state is not None:
x = self.act(x_conv[:, :, k:k+seqlen])
else:
x = self.act(x_conv[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
x = causal_conv1d_fn(
x=x,
x_conv = causal_conv1d_fn(
x=x_input,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)
if conv_state is not None:
x = x_conv[:, :, k:k+seqlen]
else:
x = x_conv

# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
Expand All @@ -186,6 +197,7 @@ def forward(self, hidden_states, inference_params=None):
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
assert self.activation in ["silu", "swish"]
# Ability to pass initial state to kernel in inference - it will incorporate exp(deltaA[0]) * h0 into the first state
y = selective_scan_fn(
x,
dt,
Expand All @@ -197,6 +209,7 @@ def forward(self, hidden_states, inference_params=None):
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
initial_state=ssm_state, # Kernel will incorporate this as exp(deltaA[0]) * h0
)
if ssm_state is not None:
y, last_state = y
Expand Down
12 changes: 8 additions & 4 deletions mamba_ssm/ops/selective_scan_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class SelectiveScanFn(torch.autograd.Function):

@staticmethod
def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
return_last_state=False):
return_last_state=False, initial_state=None):
if u.stride(-1) != 1:
u = u.contiguous()
if delta.stride(-1) != 1:
Expand All @@ -37,13 +37,15 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if initial_state is not None and initial_state.stride(-1) != 1:
initial_state = initial_state.contiguous()
if B.dim() == 3:
B = rearrange(B, "b dstate l -> b 1 dstate l")
ctx.squeeze_B = True
if C.dim() == 3:
C = rearrange(C, "b dstate l -> b 1 dstate l")
ctx.squeeze_C = True
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, initial_state)
ctx.delta_softplus = delta_softplus
ctx.has_z = z is not None
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
Expand Down Expand Up @@ -104,12 +106,14 @@ def rms_norm_forward(


def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
return_last_state=False):
return_last_state=False, initial_state=None):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
not considered in the backward pass.

initial_state: Optional (batch, dim, dstate) initial SSM state
"""
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, initial_state)


def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
Expand Down
95 changes: 95 additions & 0 deletions tests/test_mamba_chunk_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import torch
import pytest

from mamba_ssm.models.mixer_seq_simple import MixerModel
from mamba_ssm.utils.generation import InferenceParams


def _make_mamba(d_model=32, n_layers=2, d_state=16, vocab_size=100, device="cuda", dtype=torch.float32):
    """Create a small Mamba(1) MixerModel in eval mode for testing.

    Args:
        d_model: Hidden size of the model.
        n_layers: Number of mixer layers.
        d_state: SSM state size, forwarded to the Mamba layer through ``ssm_cfg``.
        vocab_size: Embedding / vocabulary size.
        device: Device the model is created on.
        dtype: Parameter dtype.

    Returns:
        A ``MixerModel`` switched to ``eval()`` mode.
    """
    model = MixerModel(
        d_model=d_model,
        n_layer=n_layers,
        d_intermediate=0,  # No MLP for simplicity
        vocab_size=vocab_size,
        # Bug fix: `d_state` was accepted by this helper but silently ignored.
        # Forward it so tests that pass a custom d_state actually exercise it.
        # NOTE(review): assumes ssm_cfg keys besides "layer" are passed through
        # to the Mamba layer constructor — confirm against create_block.
        ssm_cfg=dict(layer="Mamba1", d_state=d_state),
        device=device,
        dtype=dtype,
    )
    model.eval()
    return model


def _empty_caches_for_model(model, batch_size, device, dtype):
    """Allocate fresh inference caches for `model` and wrap them in InferenceParams.

    `device` is kept for signature compatibility; the cache device is decided by
    `model.allocate_inference_cache` itself.
    """
    # Generous upper bound on sequence length; tests use far shorter inputs.
    max_seqlen = 1024
    params = InferenceParams(
        max_seqlen=max_seqlen,
        max_batch_size=batch_size,
        seqlen_offset=0,
    )
    # Attach the per-layer cache tensors the layers will read/write during inference.
    params.key_value_memory_dict = model.allocate_inference_cache(
        batch_size, max_seqlen, dtype=dtype
    )
    return params


@pytest.mark.parametrize("device", ["cuda"])
def test_one_forward_matches_two_state_continuity_forward(device):
    """A full-sequence forward must equal two chunked forwards that carry state."""
    torch.manual_seed(0)
    batch, seqlen, d_model = 2, 30, 32
    vocab_size = 100
    split = seqlen // 2

    model = _make_mamba(d_model=d_model, n_layers=2, d_state=16, device=device, dtype=torch.float32)
    tokens = torch.randint(0, vocab_size, (batch, seqlen), device=device, dtype=torch.long)

    with torch.no_grad():
        # Reference: one uninterrupted forward pass over the whole sequence.
        expected = model(tokens)  # (B, L, D)

        # Chunked: both forwards share one InferenceParams object, so the second
        # chunk starts from the conv/SSM state written by the first. seqlen_offset
        # stays 0 so the parallel scan (with initial_state support) is exercised.
        params = _empty_caches_for_model(model, batch, device, torch.float32)
        first = model(tokens[:, :split], inference_params=params)   # (B, split, D)
        second = model(tokens[:, split:], inference_params=params)  # (B, L-split, D)
        actual = torch.cat([first, second], dim=1)

    assert torch.allclose(actual, expected, atol=1e-5, rtol=1e-5)


@pytest.mark.parametrize("device", ["cuda"])
def test_forward_matches_steps(device):
    """Two state-carrying chunked forwards must equal token-by-token decoding."""
    torch.manual_seed(0)
    batch, seqlen, d_model = 2, 30, 32
    vocab_size = 100
    split = seqlen // 2

    model = _make_mamba(d_model=d_model, n_layers=2, d_state=16, device=device, dtype=torch.float32)
    tokens = torch.randint(0, vocab_size, (batch, seqlen), device=device, dtype=torch.long)

    with torch.no_grad():
        # Reference: feed one token at a time. Whenever seqlen_offset > 0 the
        # model dispatches to its recurrent step() path internally.
        step_params = _empty_caches_for_model(model, batch, device, torch.float32)
        per_token_outputs = []
        for t in range(seqlen):
            step_params.seqlen_offset = t
            y_t = model(tokens[:, t:t + 1], inference_params=step_params)  # (B, 1, D)
            per_token_outputs.append(y_t)
        logits_step = torch.cat(per_token_outputs, dim=1)  # (B, L, D)

        # Chunked: seqlen_offset stays 0 for both calls, so each chunk runs the
        # parallel scan, the second seeded with the state left by the first.
        chunk_params = _empty_caches_for_model(model, batch, device, torch.float32)
        head = model(tokens[:, :split], inference_params=chunk_params)
        tail = model(tokens[:, split:], inference_params=chunk_params)
        logits_chunk = torch.cat([head, tail], dim=1)

    assert torch.allclose(logits_chunk, logits_step, atol=1e-5, rtol=1e-5)