diff --git a/src/torchaudio/pipelines/_wav2vec2/utils.py b/src/torchaudio/pipelines/_wav2vec2/utils.py index e690e8103c..65a7a6a2c6 100644 --- a/src/torchaudio/pipelines/_wav2vec2/utils.py +++ b/src/torchaudio/pipelines/_wav2vec2/utils.py @@ -38,7 +38,7 @@ def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[ if self.apply_log_softmax: output = torch.nn.functional.log_softmax(output, dim=-1) if self.append_star: - star_dim = torch.zeros((1, output.size(1), 1), dtype=output.dtype, device=output.device) + star_dim = torch.zeros((output.size(0), output.size(1), 1), dtype=output.dtype, device=output.device) output = torch.cat((output, star_dim), dim=-1) return output, output_lengths