diff --git a/nitrogen/mm_tokenizers.py b/nitrogen/mm_tokenizers.py index e1e6504..f393216 100644 --- a/nitrogen/mm_tokenizers.py +++ b/nitrogen/mm_tokenizers.py @@ -157,6 +157,8 @@ def _prepare_action(self, data: dict): ), f"Action dim {n_action_dims} exceeds max allowed {self.max_action_dim}." # Pad the channel dimension + # [batch, 18, 21] -> [batch, 18, 25] badded by zeros on the right. + # Action data is left aligned here [btn[0], btn[1], ..., btn[16], j_left[0], j_left[1], j_right[0], j_right[1], 0, 0, 0, 0] actions = np.pad(actions, ((0, 0), (0, self.max_action_dim - n_action_dims)), "constant") # Create mask: [T, max_action_dim] @@ -239,6 +241,10 @@ def unpack_actions(self, actions): j_right = actions[:, :, 2:4] buttons = actions[:, :, 4:] else: + # Action data is still left aligned here [btn[0], btn[1], ..., btn[16], j_left[0], j_left[1], j_right[0], j_right[1], 0, 0, 0, 0]. + # And the dimensions are [batch, 18, 25], not [batch, 18, 21]. + # Need to remove padding before using negative indices for extracting j_left, j_right. + actions = actions[:, :, :21] # Unpack the actions into j_left, j_right, buttons buttons = actions[:, :, :-4] j_left = actions[:, :, -4:-2]