diff --git a/src/torchaudio/pipelines/_wav2vec2/utils.py b/src/torchaudio/pipelines/_wav2vec2/utils.py
index e690e8103c..65a7a6a2c6 100644
--- a/src/torchaudio/pipelines/_wav2vec2/utils.py
+++ b/src/torchaudio/pipelines/_wav2vec2/utils.py
@@ -38,7 +38,7 @@ def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[
         if self.apply_log_softmax:
             output = torch.nn.functional.log_softmax(output, dim=-1)
         if self.append_star:
-            star_dim = torch.zeros((1, output.size(1), 1), dtype=output.dtype, device=output.device)
+            star_dim = torch.zeros((output.size(0), output.size(1), 1), dtype=output.dtype, device=output.device)
             output = torch.cat((output, star_dim), dim=-1)
         return output, output_lengths