Training a speaker ID model with SpeechBrain


I am having trouble training the SpeechBrain speaker identification model on my own WAV files; eventually I'd like to use it for speaker diarization. I have modified the train.py and train.yaml files from the speaker_id template, but I'm not sure whether I'm modifying them correctly. So far, the code runs without errors, but the trained model does not appear in the specified output directory.
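To check whether anything was actually saved, I've been looking under the output folder defined in train.yaml. The paths and folder names below are my assumption based on the template's output_folder/save_folder settings:

import os

# Placeholder path based on my reading of the template's output_folder/save_folder;
# replace <seed> and the prefix with whatever output_folder is set to in train.yaml.
save_folder = "./results/speaker_id/<seed>/save"

if os.path.isdir(save_folder):
    # The checkpointer should create CKPT+<timestamp> subdirectories here.
    print(os.listdir(save_folder))
else:
    print("save_folder does not exist - is output_folder set correctly in train.yaml?")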

Here is the tutorial they give: https://colab.research.google.com/drive/1UwisnAjr8nQF3UnrkIJ4abBMAWzVwBMh?usp=sharing

train.py file: https://github.com/speechbrain/speechbrain/blob/develop/templates/speaker_id/train.py
train.yaml file: https://github.com/speechbrain/speechbrain/blob/develop/templates/speaker_id/train.yaml

So far, I've set the data_folder parameter in the train.yaml file to point to my own data directory. In the train.py file, I've updated the prepare_mini_librispeech call in the main section so that its data_folder also points to my directory:

sb.utils.distributed.run_on_main(
    prepare_mini_librispeech,
    kwargs={
        "data_folder": "file_path",
        "save_json_train": hparams["train_annotation"],
        "save_json_valid": hparams["valid_annotation"],
        "save_json_test": hparams["test_annotation"],
        "split_ratio": hparams["split_ratio"],
    },
)
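
To confirm the preparation step actually picks up my data, I also inspect the generated annotation file. This is just a sanity check, and the entry structure mentioned in the comments is what I expect from the template, so it may need adjusting:

import json

# hparams is the dictionary loaded from train.yaml earlier in train.py.
with open(hparams["train_annotation"]) as f:
    train_items = json.load(f)

print(len(train_items), "training utterances")
# As far as I understand the template, each entry holds a wav path, a length,
# and a speaker id, e.g. {"wav": "...", "length": ..., "spk_id": "..."}.
print(next(iter(train_items.values())))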

I've also modified the dataio_prep function to read my own WAV files from the specified directory, replacing the audio_pipeline function and the dataset-definition code with the following:

import glob
import os

@sb.utils.data_pipeline.takes("file_path")
@sb.utils.data_pipeline.provides("sig")
def audio_pipeline(file_path):
    """Load the signal, and pass it and its length to the corruption class.
    This is done on the CPU in the `collate_fn`."""
    sig = sb.dataio.dataio.read_audio(file_path)
    return sig

# Define datasets. We also connect the dataset with the data processing
# functions defined above.
datasets = {}
data_info = {
    "train": hparams["train_annotation"],
    "valid": hparams["valid_annotation"],
    "test": hparams["test_annotation"],
}
hparams["dataloader_options"]["shuffle"] = False
for dataset in data_info:
    # Collect the wav files for this split from my data directory.
    audio_files = glob.glob(os.path.join("file_path", dataset, "*.wav"))
    datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=data_info[dataset],
        replacements={"file_path": audio_files},
        dynamic_items=[audio_pipeline, label_pipeline],
        output_keys=["id", "sig", "spk_id_encoded"],
    )
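
For comparison, this is roughly how I understand the original template builds the datasets: it loads the JSON annotations and uses replacements to substitute a single data_root string into the stored wav paths, rather than passing a list of files. This is only my reading of the template, so it may be off:

# Sketch of the template's approach as I understand it: the annotation files
# store wav paths containing a "{data_root}" placeholder, and `replacements`
# maps that placeholder to the data folder string.
for dataset in data_info:
    datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
        json_path=data_info[dataset],
        replacements={"data_root": hparams["data_folder"]},
        dynamic_items=[audio_pipeline, label_pipeline],
        output_keys=["id", "sig", "spk_id_encoded"],
    )

I'm not sure whether I should keep that pattern (letting the prepared JSON reference data_folder) or whether globbing the wav files directly, as I do above, is reasonable.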

Thank you in advance for your help!
