For the mms-tts-eng model I am getting a ushort format error

from transformers import VitsModel, AutoTokenizer
import torch

model = VitsModel.from_pretrained("facebook/mms-tts-eng")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

text = "some example text in the English language"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    output = model(**inputs).waveform

import scipy
scipy.io.wavfile.write("techno.wav", rate=model.config.sampling_rate,
                       data=output.cpu().float().numpy())

I am getting this:

error: ushort format requires 0 <= number <= (0x7fff * 2 + 1)

Salad On

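This answer solved the issue for me: simply transposing the waveform output should fix it. model(**inputs).waveform has shape (batch_size, num_samples), which is (1, num_samples) here, while scipy.io.wavfile.write treats a 2-D array as (Nsamples, Nchannels). It therefore tries to write a single frame with num_samples channels, and the channel count (or the block alignment computed from it, channels × bytes per sample) overflows one of the 16-bit unsigned fields in the WAV header, which is what the ushort error is reporting. After transposing, the shape is (num_samples, 1): many samples, one channel.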

So instead of doing this:

import scipy
scipy.io.wavfile.write("techno.wav", rate=model.config.sampling_rate,
                       data=output.cpu().float().numpy())

I did:

import scipy

# waveform is (1, num_samples); .T makes it (num_samples, 1), i.e. one channel
scipy.io.wavfile.write("techno.wav", rate=model.config.sampling_rate,
                       data=output.cpu().float().numpy().T)
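A minimal, self-contained sketch of the shape issue, assuming numpy and scipy are installed. A synthetic sine wave stands in for the model output so it runs without downloading facebook/mms-tts-eng, and the 16 kHz rate is an assumption (in the real code use model.config.sampling_rate). It writes the transposed (N, 1) array from the answer and, as an alternative, the 1-D array waveform[0] (on the torch side that would be output[0].cpu().numpy()), then reads both files back to check that they hold the same mono audio.

```python
import os
import tempfile

import numpy as np
from scipy.io import wavfile

# Stand-in for model(**inputs).waveform, which has shape (batch_size, num_samples).
# Assumption: 1 second of a 440 Hz sine at 16 kHz, batch of one.
sampling_rate = 16_000
t = np.arange(sampling_rate, dtype=np.float32) / sampling_rate
waveform = (0.5 * np.sin(2 * np.pi * 440 * t))[np.newaxis, :]  # shape (1, 16000)

with tempfile.TemporaryDirectory() as out_dir:
    # The answer's fix: (1, N) -> (N, 1), read by wavfile.write as N samples x 1 channel.
    path_t = os.path.join(out_dir, "techno_T.wav")
    wavfile.write(path_t, rate=sampling_rate, data=waveform.T)

    # Alternative: drop the batch axis; a 1-D array is written as a single channel.
    path_1d = os.path.join(out_dir, "techno_1d.wav")
    wavfile.write(path_1d, rate=sampling_rate, data=waveform[0])

    # wavfile.read returns a 1-D array for single-channel files.
    rate_t, data_t = wavfile.read(path_t)
    rate_1d, data_1d = wavfile.read(path_1d)

print(rate_t, data_t.shape, np.array_equal(data_t, data_1d))
```

Both forms produce the same mono file. Indexing with [0] keeps only the first item of the batch, so if you synthesize several texts at once you would write each row separately instead.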