|
1 | 1 | import os
|
2 | 2 | import itertools
|
| 3 | +import csv |
| 4 | +import shutil |
| 5 | +import logging |
| 6 | +from pathlib import Path |
3 | 7 | from collections import defaultdict
|
4 | 8 | from functools import partial
|
5 | 9 |
|
6 | 10 | import torch
|
7 | 11 | import torchaudio
|
8 | 12 | import librosa
|
9 | 13 | import numpy as np
|
| 14 | +import soundfile as sf |
| 15 | +from joblib import Parallel, delayed |
| 16 | +from librosa.core.audio import __audioread_load as audioread_load |
10 | 17 | from tqdm import tqdm
|
| 18 | +from torch.hub import download_url_to_file |
| 19 | +from torch.utils.data import Dataset |
| 20 | +from torchaudio.datasets.utils import extract_archive |
11 | 21 |
|
12 | 22 | import utils
|
13 | 23 |
|
@@ -352,3 +362,235 @@ def get_path(self, idx):
|
352 | 362 | speaker_id,
|
353 | 363 | f"{speaker_id}_{utterance_id}_{self._mic_id}{self._audio_ext}",
|
354 | 364 | )
|
| 365 | + |
| 366 | + |
class VoxCeleb1Dataset(SpeakerDataset, torchaudio.datasets.VoxCeleb1Identification):
    """
    VoxCeleb1 identification data exposed through the SpeakerDataset interface
    """

    def __init__(self, root, transforms=None, *args, **kwargs):
        """
        Build the torchaudio VoxCeleb1Identification base first, then the
        SpeakerDataset base with the given `transforms`.

        When `root` does not exist it is created and `download=True` is
        forced onto the torchaudio constructor.
        """
        root_is_missing = not os.path.exists(root)
        if root_is_missing:
            os.makedirs(root, exist_ok=True)
            kwargs["download"] = True
        voxceleb1 = torchaudio.datasets.VoxCeleb1Identification
        voxceleb1.__init__(self, root, *args, **kwargs)
        SpeakerDataset.__init__(self, transforms=transforms)

    def get_speakers_utterances(self):
        """
        Map every speaker id to the list of dataset indices of its utterances.

        The speaker id is the third-to-last "/"-separated component of each
        entry in `self._flist`.
        """
        indices_by_speaker = defaultdict(list)
        for position, wav_path in enumerate(self._flist):
            speaker_id, _video, _utterance = wav_path.split("/")[-3:]
            indices_by_speaker[speaker_id].append(position)
        return indices_by_speaker

    def get_sample(self, idx):
        """
        Return `(waveform, sample_rate, speaker)` for sample `idx`, dropping
        the fourth field produced by the torchaudio dataset.
        """
        item = torchaudio.datasets.VoxCeleb1Identification.__getitem__(self, idx)
        waveform, sample_rate, speaker, _unused = item
        return waveform, sample_rate, speaker

    def get_path(self, idx):
        """
        Return the file-list entry stored for sample `idx`.
        """
        return self._flist[idx]
| 399 | + |
| 400 | + |
class VoxCeleb2(Dataset):
    """
    VoxCeleb2 dataset following torchaudio's implementation of VoxCeleb1.

    Each item is a tuple `(waveform, sample_rate, speaker_id, file_id)`.

    References:
    - https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html
    - https://pytorch.org/audio/stable/_modules/torchaudio/datasets/voxceleb1.html
    """

    SAMPLE_RATE = 16000
    # Credentials from https://github.com/UoA-CARES-Student/VoxCeleb2-Dataset
    _USERNAME = "voxceleb1912"
    _PASSWORD = "0s42xuw6"
    _ARCHIVE_CONFIGS = {
        "dev": {
            "archive_name": "vox2_dev_aac.zip",
            "urls": [
                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partaa",
                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partab",
                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partac",
                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partad",
                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partae",
                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partaf",
                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partag",
                "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partah",
            ],
            "checksums": [None, None, None, None, None, None, None, None],
        },
        "test": {
            "archive_name": "vox2_test_aac.zip",
            "url": "http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_test_aac.zip",
            "checksum": "e4d9200107a7bc60f0b620d5dc04c3aab66681b649f9c218380ac43c6c722079",
        },
    }
    _IDEN_SPLIT_URL = "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/vox2_meta.csv"
    _ext_audio = ".wav"

    def __init__(self, root, subset="dev", meta_url=_IDEN_SPLIT_URL, download=False):
        """
        Args:
            root: Directory that holds (or will hold) the dataset.
            subset: Either "dev" or "test".
            meta_url: URL of the metadata CSV listing speakers and their split.
            download: Whether to download the audio when `<root>/wav` is missing.

        Raises:
            ValueError: If `subset` is not "dev" or "test".
            RuntimeError: If `<root>/wav` is missing and `download` is False.
        """
        if subset not in ["dev", "test"]:
            raise ValueError("`subset` must be one of ['dev', 'test']")
        root = os.fspath(root)
        self._path = os.path.join(root, "wav")
        if not os.path.isdir(self._path):
            if not download:
                raise RuntimeError(
                    f"Dataset not found at {self._path}. Please set `download=True` to download the dataset."
                )
            self._download_extract_wavs(root)

        # Download the `vox2_meta.csv` file to get the dev and test lists
        meta_list_path = os.path.join(root, os.path.basename(meta_url))
        if not os.path.exists(meta_list_path):
            download_url_to_file(meta_url, meta_list_path)
        self._flist = self._get_flist(root, meta_list_path, subset)

    def _convert_to_wav(self, root, paths):
        """
        Convert the .m4a files among `paths` (relative to `root`) to .wav.

        Each .wav is written next to its source with the same stem, and the
        source file is deleted afterwards. Conversions run on a thread pool.
        """

        def _to_wav(path):
            try:
                # The native sample rate returned by the loader is ignored;
                # the output is always written at SAMPLE_RATE.
                waveform, _ = audioread_load(
                    path, offset=0.0, duration=None, dtype=np.float32
                )
                path_wav = os.path.splitext(path)[0] + ".wav"
                sf.write(path_wav, waveform, self.SAMPLE_RATE)
            except Exception:
                # Best effort: one unreadable file must not abort the whole
                # conversion, but interpreter-level interrupts
                # (KeyboardInterrupt, SystemExit) are no longer swallowed.
                logging.warning(f"Could not convert file {path} to .wav.")
            # NOTE(review): the source file is removed whether or not the
            # conversion succeeded - confirm this is the intended behavior.
            os.remove(path)

        Parallel(n_jobs=-1, backend="threading")(
            delayed(_to_wav)(os.path.join(root, p))
            for p in tqdm(paths, desc="Converting audios to .wav")
            if p.endswith(".m4a")
        )

    def _download_extract_wavs(self, root):
        """
        Download dataset splits, extract zipped archives
        and convert .m4a files to .wav
        """
        os.makedirs(root, exist_ok=True)
        for split, split_config in self._ARCHIVE_CONFIGS.items():
            split_path = os.path.join(root, split_config["archive_name"])
            # The zip file of dev data is split into 8 chunks.
            # Download and combine them into one file before extraction.
            if split == "dev":
                urls = split_config["urls"]
                checksums = split_config["checksums"]
                with open(split_path, "wb") as archive:
                    for url, checksum in zip(urls, checksums):
                        part_path = os.path.join(root, os.path.basename(url))
                        utils.download_auth_url_to_file(
                            url,
                            part_path,
                            self._USERNAME,
                            self._PASSWORD,
                            hash_prefix=checksum,
                        )
                        # Stream the part instead of reading the whole
                        # chunk into memory at once.
                        with open(part_path, "rb") as part:
                            shutil.copyfileobj(part, archive)
            elif split == "test":
                url = split_config["url"]
                checksum = split_config["checksum"]
                file_path = os.path.join(root, os.path.basename(url))
                utils.download_auth_url_to_file(
                    url, file_path, self._USERNAME, self._PASSWORD, hash_prefix=checksum
                )
            extracted_paths = extract_archive(split_path)
            self._convert_to_wav(root, extracted_paths)
        # NOTE(review): performed once, after every split has been extracted
        # and converted - confirm against the archive layout.
        shutil.move(os.path.join(root, "aac"), os.path.join(root, "wav"))

    def _get_flist(self, root, meta_list_path, subset):
        """
        Load the full list of files in the given split

        The first column of the metadata CSV is read as the speaker id and the
        last column as the split name. A .wav file below `root` is selected
        when one of its parent directories is named after a speaker of
        `subset`. The result is a sorted list of path strings.
        """
        speaker_ids = set()
        with open(meta_list_path, "r", newline="") as f:
            for row in csv.reader(f, delimiter=","):
                if not row:
                    # Blank lines carry no speaker information
                    continue
                if row[-1].strip() == subset:
                    speaker_ids.add(row[0].strip())
        if not speaker_ids:
            return []

        # Walk the tree once rather than once per speaker
        root_path = Path(root)
        f_list = []
        for wav_path in root_path.rglob("*.wav"):
            parent_dirs = wav_path.relative_to(root_path).parts[:-1]
            if not speaker_ids.isdisjoint(parent_dirs):
                f_list.append(str(wav_path))
        return sorted(f_list)

    def _get_file_id(self, file_path, _ext_audio):
        """
        Return the file identifier as a combination of speaker id,
        youtube video id and utterance id
        """
        # Path.parts honours the platform separator
        speaker_id, youtube_id, utterance_id = Path(file_path).parts[-3:]
        utterance_id = utterance_id.replace(_ext_audio, "")
        file_id = "-".join([speaker_id, youtube_id, utterance_id])
        return file_id

    def get_metadata(self, n):
        """
        Get metadata for the n-th sample from the dataset.
        Returns filepath instead of waveform, but otherwise
        returns the same fields as `__getitem__`.
        """
        file_path = self._flist[n]
        file_id = self._get_file_id(file_path, self._ext_audio)
        speaker_id = file_id.split("-")[0]
        # Drop the 3-character prefix and keep the numeric part of the id
        speaker_id = int(speaker_id[3:])
        return file_path, self.SAMPLE_RATE, speaker_id, file_id

    def __getitem__(self, n):
        """
        Load the n-th sample from the dataset

        Returns:
            Tuple `(waveform, sample_rate, speaker_id, file_id)`.

        Raises:
            ValueError: If the file's sample rate differs from SAMPLE_RATE.
        """
        metadata = self.get_metadata(n)
        # Only the path is passed: the second positional parameter of
        # `torchaudio.load` is `frame_offset`, so passing the sample rate
        # there would skip the first SAMPLE_RATE frames of every utterance.
        waveform, sample_rate = torchaudio.load(metadata[0])
        if sample_rate != self.SAMPLE_RATE:
            raise ValueError(
                f"sample rate should be {self.SAMPLE_RATE}, but got {sample_rate}"
            )
        return (waveform,) + metadata[1:]

    def __len__(self):
        """
        Return the number of files in the selected split
        """
        return len(self._flist)
| 565 | + |
| 566 | + |
class VoxCeleb2Dataset(SpeakerDataset, VoxCeleb2):
    """
    VoxCeleb2 data exposed through the SpeakerDataset interface
    """

    def __init__(self, root, transforms=None, *args, **kwargs):
        """
        Build the VoxCeleb2 base first, then the SpeakerDataset base with the
        given `transforms`.

        When `root` does not exist it is created and `download=True` is
        forced onto the VoxCeleb2 constructor.
        """
        root_is_missing = not os.path.exists(root)
        if root_is_missing:
            os.makedirs(root, exist_ok=True)
            kwargs["download"] = True
        VoxCeleb2.__init__(self, root, *args, **kwargs)
        SpeakerDataset.__init__(self, transforms=transforms)

    def get_speakers_utterances(self):
        """
        Map every speaker id to the list of dataset indices of its utterances.

        The speaker id is the third-to-last "/"-separated component of each
        entry in `self._flist`.
        """
        indices_by_speaker = defaultdict(list)
        for position, wav_path in enumerate(self._flist):
            speaker_id, _video, _utterance = wav_path.split("/")[-3:]
            indices_by_speaker[speaker_id].append(position)
        return indices_by_speaker

    def get_sample(self, idx):
        """
        Return `(waveform, sample_rate, speaker)` for sample `idx`, dropping
        the file identifier produced by `VoxCeleb2.__getitem__`.
        """
        item = VoxCeleb2.__getitem__(self, idx)
        waveform, sample_rate, speaker, _file_id = item
        return waveform, sample_rate, speaker

    def get_path(self, idx):
        """
        Return the file-list entry stored for sample `idx`.
        """
        return self._flist[idx]
0 commit comments