[Bug] TRT-LLM Backend: Incorrect mask and encoder_output_lengths causing random repetitions in batch > 1 #32

@zzgnb

Description

When testing the TRT-LLM backend with runtime/triton_tensorrt/scripts/infer.py, I noticed that batch sizes > 1 produce incorrect outputs: shorter sentences in the batch may be repeated continuously instead of being transcribed correctly.

Evidence

As shown below, shorter inputs result in repetitive and incorrect transcriptions.

[Screenshot: batch inference output in which the shorter utterances are transcribed as endless repetitions]

Root Cause Analysis

I've traced the issue to incorrect encoder_output_lengths from the TRT-LLM backend. Compared to the PyTorch backend, all but the first length are clamped to the batch maximum, which points to a problem with how padding or masking is handled inside the TensorRT engine.

  • Expected encoder_output_lengths (PyTorch):
    [103, 78, 309, 236, 178]
    
  • Actual encoder_output_lengths (TRT-LLM):
    [103, 309, 309, 309, 309]
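
Downstream, encoder_output_lengths are presumably recovered by summing the mask after the encoder's two stride-2 subsamplings. A minimal sketch (the frame counts here are made up, not the ones from the log above) of why a mask derived from the padded tensor alone collapses every length to the batch maximum:

```python
import torch

def subsampled_lengths(mask):
    # Mirror the encoder's two stride-2 subsampling slices applied to the mask.
    mask = mask[:, :, :-2:2][:, :, :-2:2]
    return mask.sum(-1).squeeze(1)

# Hypothetical per-utterance frame counts; only the relative pattern matters.
input_lengths = torch.tensor([420, 318, 1242, 950, 718])
N, T = len(input_lengths), int(input_lengths.max())

# Correct mask: True only up to each utterance's true length.
good = torch.arange(T)[None, None, :] < input_lengths[:, None, None]

# Broken mask: built from the padded tensor alone, so every row is all True.
bad = torch.ones(N, 1, T, dtype=torch.bool)

print(subsampled_lengths(good).tolist())
print(subsampled_lengths(bad).tolist())  # every entry collapses to the batch max
```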
    

Context

I encountered a similar issue before with the FireRedASR TRT conversion. The old implementation would crash outright; the current official Docker image runs without errors but still produces wrong output, due to the same internal padding problem.

Proposed Solution

My proposed solution is to move the encoder's padding and mask calculation logic out of the model graph and into a preprocessing step. This ensures the correct mask is computed before being passed to the TensorRT engine.

Here are the code changes I made:


1. Externalize Preprocessing Logic

First, I created a new preprocess function to handle padding and mask generation externally.

def preprocess(self, padded_input, input_lengths):
    # Add context padding
    pad_zeros = torch.zeros(len(padded_input), 6, 80, dtype=torch.float32, device=padded_input.device)
    padded_input = torch.cat((padded_input, pad_zeros), dim=1)

    # Generate mask based on original input lengths
    N, T = padded_input.size()[:2]
    mask = torch.ones((N, T), dtype=torch.bool, device=padded_input.device)
    for i in range(N):
        mask[i, input_lengths[i]:] = 0
    mask = mask.unsqueeze(1) # Shape: [N, 1, T]

    return padded_input, mask
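
For reference, here is how the new preprocess step might be exercised standalone. This sketch replaces the per-row loop with an equivalent vectorized comparison; the shapes (6 context frames, 80-dim features) follow the code above, and the dummy inputs are made up:

```python
import torch

def preprocess(padded_input, input_lengths):
    # Append 6 zero frames of 80-dim features as context padding.
    pad_zeros = torch.zeros(len(padded_input), 6, 80,
                            dtype=torch.float32, device=padded_input.device)
    padded_input = torch.cat((padded_input, pad_zeros), dim=1)

    # Build the mask from the true lengths (vectorized form of the loop above).
    N, T = padded_input.size()[:2]
    mask = torch.arange(T, device=padded_input.device)[None, :] < input_lengths[:, None]
    return padded_input, mask.unsqueeze(1)  # Shape: [N, 1, T]

feats = torch.randn(2, 10, 80)          # dummy batch: 2 utterances, 10 frames
lengths = torch.tensor([10, 7])
x, mask = preprocess(feats, lengths)
print(x.shape)        # 10 frames + 6 context frames = 16
print(mask.sum(-1))   # per-row count of valid frames
```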

2. Modify conformer_encoder.forward

Next, I updated the encoder's forward method to accept the pre-computed mask and removed the internal padding logic.

# In conformer_encoder.forward
def forward(self, padded_input, src_mask):
    # COMMENT OUT or REMOVE internal padding/masking
    # if pad:
    #     padded_input = F.pad(padded_input,
    #         (0, 0, 0, self.input_preprocessor.context - 1), 'constant', 0.0)
    # src_mask = self.padding_position_is_0(padded_input, input_lengths)

    embed_output, src_mask = self.input_preprocessor(padded_input, src_mask)
    enc_output = self.dropout(embed_output)
    pos_emb = self.dropout(self.positional_encoding(embed_output))

    enc_outputs = []
    for enc_layer in self.layer_stack:
        enc_output = enc_layer(enc_output, pos_emb, slf_attn_mask=src_mask,
                               pad_mask=src_mask)
        enc_outputs.append(enc_output)

    return enc_output, src_mask

3. Adjust conformer_encoder.input_preprocessor.forward

Finally, I adjusted input_preprocessor.forward to correctly slice the external mask after subsampling.

# In conformer_encoder.input_preprocessor.forward
def forward(self, x, x_mask):
    x = x.unsqueeze(1)
    x = self.conv(x)
    x = x.transpose(1, 2).flatten(2)
    x = self.out(x)

    # Adjust mask slicing to match subsampling
    # OLD LOGIC:
    # mask = x_mask[:, :, :-2:2][:, :, :-2:2]
    # NEW LOGIC:
    mask = x_mask[:, :, :x_mask.shape[-1]-2:2]
    mask = mask[:, :, :mask.shape[-1]-2:2]

    # No longer need to calculate input_lengths here
    return x, mask
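
The old and new slicings select exactly the same positions; the rewrite just avoids a negative stop in the strided slice, which I assume is friendlier to TRT export. A quick check of the equivalence on a dummy mask:

```python
import torch

# Dummy [N, 1, T] mask with two shorter utterances.
x_mask = torch.ones(3, 1, 50, dtype=torch.bool)
x_mask[1, :, 30:] = False
x_mask[2, :, 17:] = False

# Old logic: negative stop in a strided slice.
old = x_mask[:, :, :-2:2][:, :, :-2:2]

# New logic: equivalent slicing with an explicit non-negative stop.
new = x_mask[:, :, :x_mask.shape[-1] - 2:2]
new = new[:, :, :new.shape[-1] - 2:2]

print(torch.equal(old, new))  # the two formulations pick identical positions
```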
