Description
When testing the TRT-LLM backend with runtime/triton_tensorrt/scripts/infer.py, I noticed that batches of size > 1 produce incorrect outputs: shorter sentences in the batch may be repeated continuously instead of being transcribed correctly.
Evidence
As shown below, shorter inputs result in repetitive and incorrect transcriptions.
Root Cause Analysis
I've traced the issue to incorrect encoder_output_lengths from the TRT-LLM backend. Compared to the PyTorch backend, the lengths are wrong, which points to a problem with how padding or masking is handled within the TensorRT engine.
- Expected (PyTorch): encoder_output_lengths = [103, 78, 309, 236, 178]
- Actual (TRT-LLM): encoder_output_lengths = [103, 309, 309, 309, 309]
Context
I encountered a similar issue before with the FireRedASR TRT conversion. The old implementation would crash, but the current official Docker image runs without errors: it simply produces wrong output due to the same internal padding problem.
Proposed Solution
My proposed solution is to move the encoder's padding and mask calculation logic out of the model graph and into a preprocessing step. This ensures the correct mask is computed before being passed to the TensorRT engine.
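As a minimal sketch of the idea (a plain-Python stand-in; `padding_mask` is a hypothetical helper, and the real code below uses torch tensors), the mask is derived purely from the true input lengths, outside the engine:

```python
# Sketch of the proposed preprocessing step: the padding mask is
# computed from the true input lengths before anything reaches the
# TensorRT engine.
def padding_mask(input_lengths, padded_len):
    # True marks a real frame, False marks a padding frame.
    return [[t < length for t in range(padded_len)]
            for length in input_lengths]

mask = padding_mask([3, 5], 5)
assert mask[0] == [True, True, True, False, False]  # padding masked out
assert mask[1] == [True] * 5                        # full-length input
```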
Here are the code changes I made:
1. Externalize Preprocessing Logic
First, I created a new preprocess function to handle padding and mask generation externally.
def preprocess(self, padded_input, input_lengths):
    # Add context padding (replaces the padding removed from the encoder)
    pad_zeros = torch.zeros(len(padded_input), 6, 80,
                            dtype=torch.float32, device=padded_input.device)
    padded_input = torch.cat((padded_input, pad_zeros), dim=1)
    # Generate mask based on original input lengths
    N, T = padded_input.size()[:2]
    mask = torch.ones((N, T), dtype=torch.bool, device=padded_input.device)
    for i in range(N):
        mask[i, input_lengths[i]:] = 0
    mask = mask.unsqueeze(1)  # Shape: [N, 1, T]
    return padded_input, mask

2. Modify conformer_encoder.forward
Next, I updated the encoder's forward method to accept the pre-computed mask and removed the internal padding logic.
# In conformer_encoder.forward
def forward(self, padded_input, src_mask):
    # COMMENT OUT or REMOVE internal padding/masking:
    # if pad:
    #     padded_input = F.pad(padded_input,
    #         (0, 0, 0, self.input_preprocessor.context - 1), 'constant', 0.0)
    # src_mask = self.padding_position_is_0(padded_input, input_lengths)
    embed_output, src_mask = self.input_preprocessor(padded_input, src_mask)
    enc_output = self.dropout(embed_output)
    pos_emb = self.dropout(self.positional_encoding(embed_output))
    enc_outputs = []
    for enc_layer in self.layer_stack:
        enc_output = enc_layer(enc_output, pos_emb, slf_attn_mask=src_mask,
                               pad_mask=src_mask)
        enc_outputs.append(enc_output)
    return enc_output, src_mask

3. Adjust conformer_encoder.input_preprocessor.forward
Finally, I adjusted input_preprocessor.forward to correctly slice the external mask after subsampling.
# In conformer_encoder.input_preprocessor.forward
def forward(self, x, x_mask):
    x = x.unsqueeze(1)
    x = self.conv(x)
    x = x.transpose(1, 2).flatten(2)
    x = self.out(x)
    # Adjust mask slicing to match subsampling
    # OLD LOGIC:
    # mask = x_mask[:, :, :-2:2][:, :, :-2:2]
    # NEW LOGIC:
    mask = x_mask[:, :, :x_mask.shape[-1] - 2:2]
    mask = mask[:, :, :mask.shape[-1] - 2:2]
    # No longer need to calculate input_lengths here
    return x, mask
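To confirm the new slicing stays aligned with the conv subsampling for any input length, here is a small standalone check. It assumes the common Conv2dSubsampling layout of two kernel-3, stride-2 convolutions without padding; the actual conv configuration in the model is not shown above, so treat this as a sketch:

```python
def conv_subsampled_len(T: int) -> int:
    # Assumed subsampling: two stacked convs, kernel 3, stride 2, no
    # padding; each layer maps T -> (T - 3) // 2 + 1.
    for _ in range(2):
        T = (T - 3) // 2 + 1
    return T

def sliced_mask_len(T: int) -> int:
    # Time length produced by applying mask[:, :, :T-2:2] twice,
    # mirroring the NEW LOGIC above.
    for _ in range(2):
        T = len(range(0, T - 2, 2))
    return T

# The sliced mask length matches the conv output length for every
# plausible input length.
for T in range(8, 512):
    assert sliced_mask_len(T) == conv_subsampled_len(T)
```

In eager PyTorch the old `[:, :, :-2:2]` slice yields the same lengths; spelling out the end index via `shape[-1]` presumably avoids the negative-index slice during engine conversion.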