
Commit 22e66ac

add training metrics
1 parent bebd310 commit 22e66ac

25 files changed: +68 −13 lines changed

README.md (+35 −7)
@@ -1,7 +1,7 @@
 # TitaNet
 
 <p align="center">
-<img src="assets/titanet-architecture.png" alt="titanet-architecture" style="width: 450px;"/>
+<img src="assets/models/titanet-architecture.png" alt="titanet-architecture" style="width: 450px;"/>
 </p>
 
 This repository contains a small-scale implementation of the following paper:
@@ -40,7 +40,35 @@ python3 src/train.py -p "./parameters.yml"
 
 Training and evaluation metrics, along with model checkpoints and results, are logged directly to a W&B project, openly accessible [here](https://wandb.ai/wadaboa/titanet). If you want to perform a custom training run, either disable W&B (see `parameters.yml`) or provide your own entity (your username), project and API key file location in `parameters.yml`. The W&B API key file is a plain-text file containing a single line with your W&B API key, which you can get from [here](https://wandb.ai/authorize).
 
-## Results
+## Training & validation
+
+This section shows training and validation metrics observed over around 75 epochs. For more metrics, please head over to the [W&B project](https://wandb.ai/wadaboa/titanet).
+
+### Baseline CE vs TitaNet CE
+
+This experiment compares the training and validation loss and accuracy of the baseline and TitaNet models, both trained with cross-entropy loss. Training metrics reach similar values, while validation metrics are much better with TitaNet. Moreover, the plots suggest that the baseline model slightly overfits.
+
+Training Loss | Training Accuracy
+:-------------------------:|:-------------------------:
+![](assets/training/baseline-titanet-ce-train-loss.png) | ![](assets/training/baseline-titanet-ce-train-accuracy.png)
+
+Validation Loss | Validation Accuracy
+:-------------------------:|:-------------------------:
+![](assets/training/baseline-titanet-ce-val-loss.png) | ![](assets/training/baseline-titanet-ce-val-accuracy.png)
+
+### TitaNet CE vs TitaNet ArcFace
+
+This experiment compares the training and validation loss and accuracy of two TitaNet models (model size "s"), trained with cross-entropy and ArcFace loss respectively. The ArcFace parameters (scale and margin) are the ones specified in the original paper (30 and 0.2). Metrics are quite similar and no major differences can be observed.
+
+Training Loss | Training Accuracy
+:-------------------------:|:-------------------------:
+![](assets/training/titanet-ce-arc-train-loss.png) | ![](assets/training/titanet-ce-arc-train-accuracy.png)
+
+Validation Loss | Validation Accuracy
+:-------------------------:|:-------------------------:
+![](assets/training/titanet-ce-arc-val-loss.png) | ![](assets/training/titanet-ce-arc-val-accuracy.png)
+
+## Visualizations
 
 This section shows some visual results obtained after training each embedding model for around 75 epochs. Please note that all figures represent the same set of utterances, even though different figures use different colours for the same speaker.
 
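As a side note on the W&B paragraph above, here is a minimal sketch of what a custom run amounts to (the key-file name and the `entity`/`project` values are illustrative, not the exact `parameters.yml` keys):

```python
import wandb

# Hypothetical sketch: read the single-line API key file mentioned above,
# log in, and start a run under your own entity/project.
with open("wandb_api_key.txt") as f:
    wandb.login(key=f.read().strip())

run = wandb.init(entity="your-username", project="titanet")
run.log({"train/loss": 0.42})  # logged metrics appear in the W&B dashboard
run.finish()
```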
@@ -50,28 +78,28 @@ This test compares the baseline and TitaNet models on the LibriSpeech dataset us
 
 Baseline | TitaNet
 :-------------------------:|:-------------------------:
-![](results/ls-baseline-ce-umap.png) | ![](results/ls-titanet-ce-umap.png)
+![](assets/visualization/ls-baseline-ce-umap.png) | ![](assets/visualization/ls-titanet-ce-umap.png)
 
 ### Baseline vs TitaNet on VCTK
 
 This test compares the baseline and TitaNet models on the VCTK dataset, which was unseen during training. Both models were trained with cross-entropy loss and 2D projections were obtained with UMAP. As above, TitaNet beats the baseline model by a large margin.
 
 Baseline | TitaNet
 :-------------------------:|:-------------------------:
-![](results/vctk-baseline-ce-umap.png) | ![](results/vctk-titanet-ce-umap.png)
+![](assets/visualization/vctk-baseline-ce-umap.png) | ![](assets/visualization/vctk-titanet-ce-umap.png)
 
 ### SVD vs UMAP reduction
 
 This test compares two 2D reduction methods, namely SVD and UMAP. Both figures rely on the TitaNet model trained with cross-entropy loss. The choice of reduction method strongly influences our subjective evaluation, with UMAP giving much better separation in the latent space.
 
-TitaNet LS SVD | TitaNet LS UMAP
+SVD | UMAP
 :-------------------------:|:-------------------------:
-![](results/ls-titanet-ce-svd.png) | ![](results/ls-titanet-ce-umap.png)
+![](assets/visualization/ls-titanet-ce-svd.png) | ![](assets/visualization/ls-titanet-ce-umap.png)
 
 ### Cross-entropy vs ArcFace loss
 
 This test compares two TitaNet models, one trained with cross-entropy loss and the other with ArcFace loss. Both figures rely on UMAP as the 2D reduction method. There doesn't seem to be a clear winner in this example, as both models obtain good clustering properties.
 
 Cross-entropy | ArcFace
 :-------------------------:|:-------------------------:
-![](results/ls-titanet-ce-umap.png) | ![](results/ls-titanet-arc-umap.png)
+![](assets/visualization/ls-titanet-ce-umap.png) | ![](assets/visualization/ls-titanet-arc-umap.png)
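Since both the training curves and the visualizations above compare cross-entropy with ArcFace, a brief reminder of what ArcFace does may help: it adds an angular margin to the target-class cosine logit before scaling. A minimal sketch with the scale and margin mentioned above (30 and 0.2); this is illustrative, not the repository's exact implementation:

```python
import torch
import torch.nn.functional as F

def arcface_logits(embeddings, class_weights, labels, scale=30.0, margin=0.2):
    # Cosine similarity between L2-normalized embeddings and class weights
    cosine = F.normalize(embeddings) @ F.normalize(class_weights).t()
    # Recover angles and add the margin only to each sample's target class
    theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
    target = F.one_hot(labels, num_classes=class_weights.size(0)).bool()
    logits = torch.where(target, torch.cos(theta + margin), cosine)
    return scale * logits

# Training then applies the usual cross-entropy on the adjusted logits:
# loss = F.cross_entropy(arcface_logits(emb, weights, labels), labels)
```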
Remaining files: eight binary asset images changed (two of them 369 KB and 414 KB), plus 12 files renamed without changes.

src/learn.py (+4 −1)
@@ -410,6 +410,7 @@ def evaluate(
 def test(
     model,
     test_dataset,
+    indices=None,
     wandb_run=None,
     log_console=True,
     mindcf_p_target=0.01,
@@ -425,7 +426,7 @@ def test(
 
     # Get cosine similarity scores and labels
     samples = (
-        test_dataset.get_sample_pairs(device=device)
+        test_dataset.get_sample_pairs(indices=indices, device=device)
         if not isinstance(test_dataset, torch.utils.data.Subset)
         else test_dataset.dataset.get_sample_pairs(
             indices=test_dataset.indices, device=device
@@ -455,6 +456,8 @@ def test(
     if wandb_run is not None:
         wandb_run.notes = json.dumps(metrics, indent=2).encode("utf-8")
 
+    return metrics
+
 
 def infer(
     model,
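Taken together, these changes let `test` evaluate on a chosen subset of utterances and hand the computed metrics back to the caller. A usage sketch (the `titanet_model`, `ls_dataset` and `ls_utterances` names are illustrative, mirroring the notebook cell added below):

```python
# Hypothetical call site for the updated test() signature.
metrics = test(
    titanet_model,
    ls_dataset,
    indices=ls_utterances,  # restrict cosine-similarity pairs to these utterances
    log_console=False,
    device=device,
)
print(metrics)  # the metrics dict is now returned instead of only being logged
```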

titanet.ipynb (+29 −5)
@@ -197,7 +197,7 @@
  {
   "cell_type": "code",
   "execution_count": 49,
-  "id": "92f4d67d",
+  "id": "d9dcaf4a",
   "metadata": {},
   "outputs": [],
   "source": [
@@ -594,7 +594,7 @@
   "id": "7c2b36b1",
   "metadata": {},
   "source": [
-   "<img src=\"assets/d-vector.png\" alt=\"d-vector\" style=\"width: 450px;\"/>\n",
+   "<img src=\"assets/models/d-vector.png\" alt=\"d-vector\" style=\"width: 450px;\"/>\n",
    " \n",
    "Our baseline model is based on the d-vector concept. A d-vector is simply a way to refer to speaker embeddings generated by a DNN (Deep Neural Network), hence the \"d\" prefix. The standard way to compute such d-vectors, as described in [Generalized End-to-End Loss for Speaker Verification](https://arxiv.org/abs/1710.10467), is through a stack of LSTM layers processing spectrogram segments. In particular, the full spectrogram of shape $B\\times M\\times T$ is unfolded in a sequence of tensors of shape $B\\times M \\times S$, where $S$ is the segment length. Then, each segment is fed into a recurrent module and hidden states are collapsed in a single dimension by either averaging or simply taking the last one. Collapsed vectors are then projected onto the embedding size and once we have one embedding vector for each segment, the embedding vector of the full spectrogram is just the average of all its constituent segments' embeddings."
   ]
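The segment-and-average procedure described in the markdown cell above can be summarized in a short PyTorch sketch (the class name and layer sizes are illustrative, not the repository's actual baseline):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DVectorSketch(nn.Module):
    """Illustrative d-vector baseline: an LSTM over spectrogram segments."""

    def __init__(self, n_mels=80, hidden_size=256, embedding_size=192):
        super().__init__()
        self.lstm = nn.LSTM(n_mels, hidden_size, num_layers=3, batch_first=True)
        self.proj = nn.Linear(hidden_size, embedding_size)

    def forward(self, spectrogram, segment_length=160):
        # (B, M, T) -> non-overlapping segments of shape (B, M, n_seg, S)
        segments = spectrogram.unfold(-1, segment_length, segment_length)
        b, m, n_seg, s = segments.shape
        # LSTM expects (batch, time, features): fold segments into the batch
        x = segments.permute(0, 2, 3, 1).reshape(b * n_seg, s, m)
        out, _ = self.lstm(x)
        emb = self.proj(out[:, -1])  # collapse by taking the last hidden state
        # Average per-segment embeddings into one utterance-level embedding
        return F.normalize(emb.reshape(b, n_seg, -1).mean(dim=1), dim=-1)
```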
@@ -628,7 +628,7 @@
   "id": "1b9a37c4",
   "metadata": {},
   "source": [
-   "<img src=\"assets/titanet-architecture.png\" alt=\"titanet-architecture\" style=\"width: 450px;\"/>"
+   "<img src=\"assets/models/titanet-architecture.png\" alt=\"titanet-architecture\" style=\"width: 450px;\"/>"
   ]
  },
  {
@@ -1137,7 +1137,7 @@
  {
   "cell_type": "code",
   "execution_count": 83,
-  "id": "864d1f1a",
+  "id": "a0e5fb74",
   "metadata": {},
   "outputs": [
    {
@@ -1506,6 +1506,30 @@
    ")"
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "235d9152",
+  "metadata": {},
+  "outputs": [
+   {
+    "name": "stderr",
+    "output_type": "stream",
+    "text": [
+     "Loading sample pairs: 659it [00:10, 63.73it/s]"
+    ]
+   }
+  ],
+  "source": [
+   "learn.test(\n",
+   "    titanet_model, \n",
+   "    ls_dataset, \n",
+   "    indices=ls_utterances, \n",
+   "    log_console=False,\n",
+   "    device=device\n",
+   ")"
+  ]
+ },
  {
   "cell_type": "markdown",
   "id": "ddbdd0a1",
@@ -1762,7 +1786,7 @@
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "f486921c",
+  "id": "88126c4a",
   "metadata": {},
   "outputs": [],
   "source": []
