Commit 5aa3e92

Update seed
Fix text
1 parent 96f3d51 commit 5aa3e92

5 files changed: +79 −36 lines

config.py (+2 −2)

@@ -6,15 +6,14 @@
 After the initial launch that automatically generates the config.yaml file, any modifications to the configuration should be made directly in the config.yaml file, not in the config.py file.
 """
 
-import copy
 import logging
 import os
 import secrets
 import shutil
 import string
 import sys
 import traceback
-from dataclasses import dataclass, field, asdict, fields, is_dataclass
+from dataclasses import dataclass, field, fields, is_dataclass
 from typing import List, Union, Optional, Dict
 
 import torch
@@ -183,6 +182,7 @@ class GPTSoVitsConfig(AsDictMixin):
     use_streaming: bool = False
     batch_size: int = 5
     speed: float = 1.0
+    seed: int = -1
     presets: Dict[str, GPTSoVitsPreset] = field(default_factory=lambda: {"default": GPTSoVitsPreset(),
                                                                          "default2": GPTSoVitsPreset()})

gpt_sovits/gpt_sovits.py (+42 −15)

@@ -1,21 +1,21 @@
 import logging
 import math
 import os.path
+import random
 import re
 from typing import List
 
 import librosa
 import numpy as np
 import torch
-from time import time as ttime
 
 from contants import config
 from gpt_sovits.AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from gpt_sovits.module.mel_processing import spectrogram_torch
 from gpt_sovits.module.models import SynthesizerTrn
-from gpt_sovits.utils import DictToAttrRecursive
 from gpt_sovits.text import cleaned_text_to_sequence
 from gpt_sovits.text.cleaner import clean_text
+from gpt_sovits.utils import DictToAttrRecursive
 from utils.classify_language import classify_language
 from utils.data_utils import check_is_none
 from utils.sentence import split_languages, sentence_split
@@ -120,6 +120,25 @@ def load_gpt(self, gpt_path):
         total = sum([param.nelement() for param in self.t2s_model.parameters()])
         logging.info(f"Number of parameter: {total / 1e6:.2f}M")
 
+    def set_seed(self, seed: int):
+        seed = int(seed)
+        seed = seed if seed != -1 else random.randrange(1 << 32)
+        logging.debug(f"Set seed to {seed}")
+        os.environ['PYTHONHASHSEED'] = str(seed)
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        try:
+            if torch.cuda.is_available():
+                torch.cuda.manual_seed(seed)
+                torch.cuda.manual_seed_all(seed)
+                # torch.backends.cudnn.deterministic = True
+                # torch.backends.cudnn.benchmark = False
+                # torch.backends.cudnn.enabled = True
+        except:
+            pass
+        return seed
+
     def get_speakers(self):
         return self.speakers
 
@@ -165,20 +184,21 @@ def get_bert_feature(self, text, phones, word2ph, language):
     def get_bert_and_cleaned_text_multilang(self, text: list):
         sentences = split_languages(text, expand_abbreviations=True, expand_hyphens=True)
 
-        phones, word2ph, norm_text, bert = [], [], [], []
+        phones_list, word2ph_list, norm_text_list, bert_list = [], [], [], []
 
         for sentence, lang in sentences:
-            _phones, _word2ph, _norm_text = self.get_cleaned_text(sentence, lang)
-            _bert = self.get_bert_feature(sentence, _phones, _word2ph, _norm_text)
-            phones.extend(_phones)
-            if _word2ph is not None:
-                word2ph.extend(_word2ph)
-            norm_text.extend(_norm_text)
-            bert.append(_bert)
+            phones, word2ph, _norm_text = self.get_cleaned_text(sentence, lang)
+            bert = self.get_bert_feature(sentence, phones, word2ph, _norm_text)
+            phones_list.extend(phones)
+            if word2ph is not None:
+                word2ph_list.extend(word2ph)
+            norm_text_list.extend(_norm_text)
+            bert_list.append(bert)
 
-        bert = torch.cat(bert, dim=1).to(self.device, dtype=self.torch_dtype)
+        norm_text = ''.join(norm_text_list)
+        bert = torch.cat(bert_list, dim=1).to(self.device, dtype=self.torch_dtype)
 
-        return phones, word2ph, norm_text, bert
+        return phones_list, word2ph_list, norm_text, bert
 
     def get_spepc(self, audio, orig_sr):
         """audio的sampling_rate与模型相同"""
@@ -238,6 +258,11 @@ def preprocess_text(self, text: str, lang: str, segment_size: int):
 
         result = []
         for text in texts:
+            text = text.strip("\n")
+            if (text[0] not in splits and len(self.get_first(text)) < 4):
+                text = "。" + text if lang != "en" else "." + text
+            if (text[-1] not in splits):
+                text += "。" if lang != "en" else "."
             phones, word2ph, norm_text, bert_features = self.get_bert_and_cleaned_text_multilang(text)
             res = {
                 "phones": phones,
@@ -251,7 +276,7 @@ def preprocess_prompt(self, reference_audio, reference_audio_sr, prompt_text: st
         if self.prompt_cache.get("prompt_text") != prompt_text:
             if prompt_lang.lower() == "auto":
                 prompt_lang = classify_language(prompt_text)
-
+            prompt_text = prompt_text.strip("\n")
             if (prompt_text[-1] not in splits):
                 prompt_text += "。" if prompt_lang != "en" else "."
             phones, word2ph, norm_text = self.get_cleaned_text(prompt_text, prompt_lang)
@@ -438,9 +463,11 @@ def speed_change(self, input_audio: np.ndarray, speed_factor: float, sr: int):
 
     def infer(self, text, lang, reference_audio, reference_audio_sr, prompt_text, prompt_lang, top_k, top_p,
               temperature, batch_size: int = 5, batch_threshold: float = 0.75, split_bucket: bool = True,
-              return_fragment: bool = False, speed_factor: float = 1.0,
+              return_fragment: bool = False, speed_factor: float = 1.0, seed: int = -1,
              segment_size: int = config.gpt_sovits_config.segment_size, **kwargs):
 
+        self.set_seed(seed)
+
         if return_fragment:
             split_bucket = False
 
@@ -476,7 +503,7 @@ def infer(self, text, lang, reference_audio, reference_audio_sr, prompt_text, pr
         if self.is_half:
            all_bert_features = all_bert_features.half()
 
-        logging.debug(f"Infer text:{[''.join(text) for text in norm_text]}")
+        logging.debug(f"Infer text:{norm_text}")
         if no_prompt_text:
             prompt = None
         else:
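Below is a standalone sketch of the seeding behaviour introduced above (an illustration under that assumption, not the project's own entry point): a seed of -1 is replaced by a fresh 32-bit random value, while any other value pins the Python, NumPy and torch RNGs so a repeated inference call samples the same trajectory.

import random

import numpy as np
import torch


def resolve_seed(seed: int) -> int:
    # Mirrors the idea of set_seed(): -1 -> draw a fresh 32-bit seed, otherwise pin all RNGs.
    seed = int(seed)
    if seed == -1:
        seed = random.randrange(1 << 32)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed


resolve_seed(42)
a = torch.rand(3)
resolve_seed(42)
b = torch.rand(3)
assert torch.equal(a, b)  # same seed -> identical sampling, hence reproducible synthesis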

tts_app/static/js/index.js (+10)

@@ -82,6 +82,7 @@ function getLink() {
     let temperature = null;
     let batch_size = null;
     let speed = null;
+    let seed = null;
 
     if (currentModelPage == 1 || currentModelPage == 2 || currentModelPage == 3) {
         length = document.getElementById("input_length" + currentModelPage).value;
@@ -112,6 +113,7 @@ function getLink() {
         top_p = document.getElementById('input_top_p4').value;
         temperature = document.getElementById('input_temperature4').value;
         batch_size = document.getElementById('input_batch_size4').value;
+        seed = document.getElementById('input_seed4').value;
         // speed = document.getElementById('input_speed4').value;
         url += "/voice/gpt-sovits?id=" + id;
 
@@ -182,6 +184,8 @@ function getLink() {
             url += "&batch_size=" + batch_size;
         if (speed !== null && speed !== "")
             url += "&speed=" + speed;
+        if (seed !== null && seed !== "")
+            url += "&seed=" + seed;
     }
 
     if (api_key != "") {
@@ -273,6 +277,7 @@ function setAudioSourceByPost() {
     let temperature = null;
     let batch_size = null;
     let speed = null;
+    let seed = null;
 
     let headers = {};
 
@@ -313,6 +318,8 @@ function setAudioSourceByPost() {
         temperature = $("#input_temperature4").val();
         batch_size = $("#input_batch_size4").val();
         // speed = $("#input_speed4").val();
+        seed = $("#input_seed4").val();
+
     }
 
 
@@ -375,6 +382,9 @@ function setAudioSourceByPost() {
     if (currentModelPage == 4 && speed) {
        formData.append('speed', speed);
     }
+    if (currentModelPage == 4 && seed) {
+        formData.append('seed', seed);
+    }
 
     let downloadButton = document.getElementById("downloadButton" + currentModelPage);
 

tts_app/templates/pages/gpt_sovits.html (+20 −12)

@@ -22,7 +22,7 @@
             </form>
             <form class="w-100">
                 <div class="row">
-                    <div class="col-md-4 mb-3">
+                    <div class="col-md-3 mb-3">
                         <label data-toggle="tooltip" data-placement="top"
                                title="默认为wav">format</label>
                         <select class="form-control input_format" id="input_format4" oninput="updateLink()">
@@ -34,22 +34,29 @@
                             <option>flac</option>
                         </select>
                     </div>
-                    <div class="col-md-4 mb-3">
+                    <div class="col-md-3 mb-3">
                         <label data-toggle="tooltip" data-placement="top"
                                title="自动识别语言auto:可识别的语言根据不同speaker而不同,方言无法自动识别。方言模型需要手动指定语言,比如粤语Cantonese要指定参数lang=gd">lang</label>
                         <input type="text" class="form-control input_lang" id="input_lang4" oninput="updateLink()"
                                value=""
                                placeholder="auto"/>
                     </div>
-                    <div class="col-md-4 mb-3">
+                    <div class="col-md-3 mb-3">
                         <label data-toggle="tooltip" data-placement="top"
                                title="按标点符号分段,加起来大于segment_size时为一段文本。segment_size<=0表示不分段。">segment_size</label>
                         <input type="number" class="form-control input_segment_size" id="input_segment_size4"
                                oninput="updateLink()"
                                value=""
                                placeholder="50" step="1"/>
                     </div>
-
+                    <div class="col-md-3 mb-3">
+                        <label for="seed" data-toggle="tooltip" data-placement="top"
+                               title="随机种子">seed</label>
+                        <input type="text" class="form-control seed" id="input_seed4"
+                               oninput="updateLink()"
+                               value=""
+                               placeholder="5"/>
+                    </div>
                 </div>
 
                 <div class="row">
@@ -85,14 +92,15 @@
                                value=""
                                placeholder="5"/>
                     </div>
-{# <div class="col-md-2 mb-3">#}
-{# <label for="speed" data-toggle="tooltip" data-placement="top"#}
-{# title="">speed</label>#}
-{# <input type="text" class="form-control speed" id="input_speed4"#}
-{# oninput="updateLink()"#}
-{# value=""#}
-{# placeholder="1.0"/>#}
-{# </div>#}
+
+{# <div class="col-md-2 mb-3">#}
+{# <label for="speed" data-toggle="tooltip" data-placement="top"#}
+{# title="">speed</label>#}
+{# <input type="text" class="form-control speed" id="input_speed4"#}
+{# oninput="updateLink()"#}
+{# value=""#}
+{# placeholder="1.0"/>#}
+{# </div>#}
 
                 </div>
                 <div class="row">

tts_app/voice_api/views.py (+5 −7)

@@ -1,22 +1,18 @@
 import copy
-import logging
 import os
 import time
-import traceback
 import uuid
 from io import BytesIO
 
-import librosa
-import numpy as np
 from flask import request, jsonify, make_response, send_file, Blueprint
 from werkzeug.utils import secure_filename
 
+from contants import ModelType
 from contants import config
 # from gpt_sovits.utils import load_audio
 from logger import logger
-from contants import ModelType
-from tts_app.voice_api.auth import require_api_key
 from tts_app.model_manager import model_manager, tts_manager
+from tts_app.voice_api.auth import require_api_key
 from tts_app.voice_api.utils import *
 from utils.data_utils import check_is_none
 
@@ -586,6 +582,7 @@ def voice_gpt_sovits_api():
         use_streaming = get_param(request_data, 'streaming', config.gpt_sovits_config.use_streaming, bool)
         batch_size = get_param(request_data, 'batch_size', config.gpt_sovits_config.batch_size, int)
         speed_factor = get_param(request_data, 'speed', config.gpt_sovits_config.speed, float)
+        seed = get_param(request_data, 'seed', config.gpt_sovits_config.seed, int)
     except Exception as e:
         logger.error(f"[{ModelType.GPT_SOVITS.value}] {e}")
         return make_response("parameter error", 400)
@@ -643,7 +640,8 @@ def voice_gpt_sovits_api():
         "temperature": temperature,
         "preset": preset,
         "batch_size": batch_size,
-        "speed_factor": speed_factor
+        "speed_factor": speed_factor,
+        "seed": seed
     }
 
     if use_streaming:
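For reference, a hypothetical client-side call exercising the new parameter; the host, port, speaker id and text below are placeholder values and only the /voice/gpt-sovits path and the id/seed parameter names come from this diff.

import requests

# Passing an explicit seed makes repeated requests reproducible; omitting it (or
# sending -1) falls back to the new config default, which draws a random per-request seed.
resp = requests.get(
    "http://127.0.0.1:23456/voice/gpt-sovits",      # placeholder host/port
    params={"id": 0, "text": "Hello", "seed": 42},  # id/text values are placeholders
)
with open("output.wav", "wb") as f:
    f.write(resp.content)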
