
Commit 6cbff17

add voxceleb1 / voxceleb2 datasets
1 parent 5d08522 commit 6cbff17

4 files changed: +287 -3 lines changed


init/colab_requirements.txt

+2 -1

@@ -1,7 +1,8 @@
-torchaudio==0.9.0
+torchaudio==0.13.0
 pyyaml==5.4.1
 scikit-learn==1.0
 wandb==0.12.4
 rich==10.12.0
 umap-learn==0.5.1
 librosa==0.8.1
+joblib==1.1.0

init/requirements.txt

+3 -2

@@ -1,7 +1,7 @@
 black==21.9b0
 tqdm==4.62.3
-torch==1.9.1
-torchaudio==0.9.1
+torch==1.13.0
+torchaudio==0.13.0
 matplotlib==3.4.3
 pandas==1.3.3
 pyyaml==5.4.1
@@ -12,3 +12,4 @@ requests==2.26.0
 rich==10.12.0
 umap-learn==0.5.1
 librosa==0.8.1
+joblib==1.1.0

src/datasets.py

+242

@@ -1,13 +1,23 @@
 import os
 import itertools
+import csv
+import shutil
+import logging
+from pathlib import Path
 from collections import defaultdict
 from functools import partial
 
 import torch
 import torchaudio
 import librosa
 import numpy as np
+import soundfile as sf
+from joblib import Parallel, delayed
+from librosa.core.audio import __audioread_load as audioread_load
 from tqdm import tqdm
+from torch.hub import download_url_to_file
+from torch.utils.data import Dataset
+from torchaudio.datasets.utils import extract_archive
 
 import utils
 
@@ -352,3 +362,235 @@ def get_path(self, idx):
             speaker_id,
             f"{speaker_id}_{utterance_id}_{self._mic_id}{self._audio_ext}",
         )
+
+
+class VoxCeleb1Dataset(SpeakerDataset, torchaudio.datasets.VoxCeleb1Identification):
+    """
+    Custom VoxCeleb1 dataset for speaker-related tasks
+    """
+
+    def __init__(self, root, transforms=None, *args, **kwargs):
+        if not os.path.exists(root):
+            os.makedirs(root, exist_ok=True)
+            kwargs["download"] = True
+        torchaudio.datasets.VoxCeleb1Identification.__init__(
+            self, root, *args, **kwargs
+        )
+        SpeakerDataset.__init__(self, transforms=transforms)
+
+    def get_speakers_utterances(self):
+        speakers_utterances = defaultdict(list)
+        for i, file_path in enumerate(self._flist):
+            speaker_id, _, _ = file_path.split("/")[-3:]
+            speakers_utterances[speaker_id].append(i)
+        return speakers_utterances
+
+    def get_sample(self, idx):
+        (
+            waveform,
+            sample_rate,
+            speaker,
+            _,
+        ) = torchaudio.datasets.VoxCeleb1Identification.__getitem__(self, idx)
+        return waveform, sample_rate, speaker
+
+    def get_path(self, idx):
+        return self._flist[idx]
+
+
+class VoxCeleb2(Dataset):
+    """
+    VoxCeleb2 dataset following torchaudio's implementation of VoxCeleb1.
+
+    References:
+    - https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html
+    - https://pytorch.org/audio/stable/_modules/torchaudio/datasets/voxceleb1.html
+    """
+
+    SAMPLE_RATE = 16000
+    # Credentials from https://github.com/UoA-CARES-Student/VoxCeleb2-Dataset
+    _USERNAME = "voxceleb1912"
+    _PASSWORD = "0s42xuw6"
+    _ARCHIVE_CONFIGS = {
+        "dev": {
+            "archive_name": "vox2_dev_aac.zip",
+            "urls": [
+                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partaa",
+                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partab",
+                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partac",
+                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partad",
+                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partae",
+                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partaf",
+                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partag",
+                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partah",
+            ],
+            "checksums": [None, None, None, None, None, None, None, None],
+        },
+        "test": {
+            "archive_name": "vox2_test_aac.zip",
+            "url": "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_test_aac.zip",
+            "checksum": "e4d9200107a7bc60f0b620d5dc04c3aab66681b649f9c218380ac43c6c722079",
+        },
+    }
+    _IDEN_SPLIT_URL = "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/vox2_meta.csv"
+    _ext_audio = ".wav"
+
+    def __init__(self, root, subset="dev", meta_url=_IDEN_SPLIT_URL, download=False):
+        if subset not in ["dev", "test"]:
+            raise ValueError("`subset` must be one of ['dev', 'test']")
+        root = os.fspath(root)
+        self._path = os.path.join(root, "wav")
+        if not os.path.isdir(self._path):
+            if not download:
+                raise RuntimeError(
+                    f"Dataset not found at {self._path}. Please set `download=True` to download the dataset."
+                )
+            self._download_extract_wavs(root)
+
+        # Download the `vox2_meta.csv` file to get the dev and test lists
+        meta_list_path = os.path.join(root, os.path.basename(meta_url))
+        if not os.path.exists(meta_list_path):
+            download_url_to_file(meta_url, meta_list_path)
+        self._flist = self._get_flist(root, meta_list_path, subset)
+
+    def _convert_to_wav(self, root, paths):
+        """
+        Convert .m4a files in the given paths to .wav
+        """
+
+        def _to_wav(path):
+            try:
+                waveform, _ = audioread_load(
+                    path, offset=0.0, duration=None, dtype=np.float32
+                )
+                path_wav = os.path.splitext(path)[0] + ".wav"
+                sf.write(path_wav, waveform, self.SAMPLE_RATE)
+            except Exception:
+                logging.warning(f"Could not convert file {path} to .wav.")
+            os.remove(path)
+
+        Parallel(n_jobs=-1, backend="threading")(
+            delayed(_to_wav)(os.path.join(root, p))
+            for p in tqdm(paths, desc="Converting audios to .wav")
+            if p.endswith(".m4a")
+        )
+
+    def _download_extract_wavs(self, root):
+        """
+        Download dataset splits, extract zipped archives
+        and convert .m4a files to .wav
+        """
+        if not os.path.isdir(root):
+            os.makedirs(root)
+        for split, split_config in self._ARCHIVE_CONFIGS.items():
+            split_name = split_config["archive_name"]
+            split_path = os.path.join(root, split_name)
+            # The zip file of dev data is split into 8 chunks.
+            # Download and combine them into one file before extraction.
+            if split == "dev":
+                urls = split_config["urls"]
+                checksums = split_config["checksums"]
+                with open(split_path, "wb") as f:
+                    for url, checksum in zip(urls, checksums):
+                        file_path = os.path.join(root, os.path.basename(url))
+                        utils.download_auth_url_to_file(
+                            url,
+                            file_path,
+                            self._USERNAME,
+                            self._PASSWORD,
+                            hash_prefix=checksum,
+                        )
+                        with open(file_path, "rb") as f_split:
+                            f.write(f_split.read())
+            elif split == "test":
+                url = split_config["url"]
+                checksum = split_config["checksum"]
+                file_path = os.path.join(root, os.path.basename(url))
+                utils.download_auth_url_to_file(
+                    url, file_path, self._USERNAME, self._PASSWORD, hash_prefix=checksum
+                )
+            extracted_paths = extract_archive(split_path)
+            self._convert_to_wav(root, extracted_paths)
+        shutil.move(os.path.join(root, "aac"), os.path.join(root, "wav"))
+
+    def _get_flist(self, root, meta_list_path, subset):
+        """
+        Load the full list of files in the given split
+        """
+        f_list = []
+        with open(meta_list_path, "r") as f:
+            csv_file = csv.reader(f, delimiter=",")
+            for line in csv_file:
+                id, set = line[0].strip(), line[-1].strip()
+                if set == subset:
+                    f_list += [str(i) for i in Path(root).rglob(f"{id}/**/*.wav")]
+        return sorted(f_list)
+
+    def _get_file_id(self, file_path, _ext_audio):
+        """
+        Return the file identifier as a combination of speaker id,
+        youtube video id and utterance id
+        """
+        speaker_id, youtube_id, utterance_id = file_path.split("/")[-3:]
+        utterance_id = utterance_id.replace(_ext_audio, "")
+        file_id = "-".join([speaker_id, youtube_id, utterance_id])
+        return file_id
+
+    def get_metadata(self, n):
+        """
+        Get metadata for the n-th sample from the dataset.
+        Returns filepath instead of waveform, but otherwise
+        returns the same fields as `__getitem__`.
+        """
+        file_path = self._flist[n]
+        file_id = self._get_file_id(file_path, self._ext_audio)
+        speaker_id = file_id.split("-")[0]
+        speaker_id = int(speaker_id[3:])
+        return file_path, self.SAMPLE_RATE, speaker_id, file_id
+
+    def __getitem__(self, n):
+        """
+        Load the n-th sample from the dataset
+        """
+        metadata = self.get_metadata(n)
+        waveform, sample_rate = torchaudio.load(metadata[0])
+        if sample_rate != self.SAMPLE_RATE:
+            raise ValueError(
+                f"sample rate should be {self.SAMPLE_RATE}, but got {sample_rate}"
+            )
+        return (waveform,) + metadata[1:]
+
+    def __len__(self):
+        return len(self._flist)
+
+
+class VoxCeleb2Dataset(SpeakerDataset, VoxCeleb2):
+    """
+    Custom VoxCeleb2 dataset for speaker-related tasks
+    """
+
+    def __init__(self, root, transforms=None, *args, **kwargs):
+        if not os.path.exists(root):
+            os.makedirs(root, exist_ok=True)
+            kwargs["download"] = True
+        VoxCeleb2.__init__(self, root, *args, **kwargs)
+        SpeakerDataset.__init__(self, transforms=transforms)
+
+    def get_speakers_utterances(self):
+        speakers_utterances = defaultdict(list)
+        for i, file_path in enumerate(self._flist):
+            speaker_id, _, _ = file_path.split("/")[-3:]
+            speakers_utterances[speaker_id].append(i)
+        return speakers_utterances
+
+    def get_sample(self, idx):
+        (
+            waveform,
+            sample_rate,
+            speaker,
+            _,
+        ) = VoxCeleb2.__getitem__(self, idx)
+        return waveform, sample_rate, speaker
+
+    def get_path(self, idx):
+        return self._flist[idx]
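
For reference, a minimal usage sketch of the two wrappers added above. The roots below are placeholder paths, the snippet assumes it is run from src/ (so that the module-level `import utils` in datasets.py resolves), and passing a root that does not exist yet triggers the full download and extraction of the corresponding dataset:

from datasets import VoxCeleb1Dataset, VoxCeleb2Dataset

# Placeholder roots: data is downloaded and extracted here on first use
vox1 = VoxCeleb1Dataset("../data/VoxCeleb1")
vox2 = VoxCeleb2Dataset("../data/VoxCeleb2", subset="dev")

# Utterance indices grouped by speaker id, e.g. {"id00012": [0, 1, ...], ...}
speakers_utterances = vox2.get_speakers_utterances()

# Raw access: waveform tensor, sample rate (16 kHz) and speaker id
waveform, sample_rate, speaker = vox2.get_sample(0)
print(waveform.shape, sample_rate, speaker, vox2.get_path(0))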

src/utils.py

+40

@@ -2,6 +2,8 @@
 import datetime
 import os
 import string
+import hashlib
+import requests
 
 import torch
 import numpy as np
@@ -10,6 +12,7 @@
 import IPython.display as ipd
 import wandb
 import umap
+from tqdm import tqdm
 from sklearn.manifold import TSNE
 from sklearn.decomposition import TruncatedSVD
 from scipy.spatial import ConvexHull
@@ -463,3 +466,40 @@ def chart_dependencies(model, n_mels=80, device="cpu"):
     ).all() and (
         inputs.grad[random_index] != 0
     ).any(), f"Only index {random_index} should have non-zero gradients"
+
+
+def download_auth_url_to_file(
+    url, file_path, username, password, hash_prefix=None, progress=True
+):
+    """
+    Download the file at the given URL using the given credentials,
+    and finally double check the checksum of the downloaded file
+    """
+    if hash_prefix is not None:
+        sha256 = hashlib.sha256()
+    response = requests.get(url, auth=(username, password), stream=True)
+    if response.status_code == 200:
+        file_size = int(response.headers.get("content-length", 0))
+        with open(file_path, "wb") as out:
+            with tqdm(
+                total=file_size,
+                disable=not progress,
+                unit="B",
+                unit_scale=True,
+                unit_divisor=1024,
+            ) as pbar:
+                for buffer in response.iter_content():
+                    out.write(buffer)
+                    if hash_prefix is not None:
+                        sha256.update(buffer)
+                    pbar.update(len(buffer))
+        if hash_prefix is not None:
+            digest = sha256.hexdigest()
+            if digest[: len(hash_prefix)] != hash_prefix:
+                raise RuntimeError(
+                    f'invalid hash value (expected "{hash_prefix}", got "{digest}")'
+                )
+        return True
+    raise RuntimeError(
+        f"Couldn't download from url {url}, got response status code {response.status_code}"
+    )
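
For reference, a hypothetical call of the new download helper; the URL, credentials and destination below are placeholders, not real endpoints (the values actually used by the VoxCeleb2 dataset live in its _ARCHIVE_CONFIGS, _USERNAME and _PASSWORD attributes). The helper streams the response to file_path behind a tqdm progress bar and, when hash_prefix is given, verifies the SHA-256 digest of the downloaded file:

import utils

utils.download_auth_url_to_file(
    url="http://example.com/protected/vox2_dev_aac_partaa",  # placeholder URL
    file_path="../data/vox2_dev_aac_partaa",                 # local destination
    username="user",                                         # HTTP basic auth credentials
    password="secret",
    hash_prefix=None,  # or a hex SHA-256 digest (prefix) to verify the file
    progress=True,
)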
