1
1
import logging
2
2
import math
3
3
import os .path
4
+ import random
4
5
import re
5
6
from typing import List
6
7
7
8
import librosa
8
9
import numpy as np
9
10
import torch
10
- from time import time as ttime
11
11
12
12
from contants import config
13
13
from gpt_sovits .AR .models .t2s_lightning_module import Text2SemanticLightningModule
14
14
from gpt_sovits .module .mel_processing import spectrogram_torch
15
15
from gpt_sovits .module .models import SynthesizerTrn
16
- from gpt_sovits .utils import DictToAttrRecursive
17
16
from gpt_sovits .text import cleaned_text_to_sequence
18
17
from gpt_sovits .text .cleaner import clean_text
18
+ from gpt_sovits .utils import DictToAttrRecursive
19
19
from utils .classify_language import classify_language
20
20
from utils .data_utils import check_is_none
21
21
from utils .sentence import split_languages , sentence_split
@@ -120,6 +120,25 @@ def load_gpt(self, gpt_path):
120
120
total = sum ([param .nelement () for param in self .t2s_model .parameters ()])
121
121
logging .info (f"Number of parameter: { total / 1e6 :.2f} M" )
122
122
123
def set_seed(self, seed: int):
    """Seed every RNG source (Python hash, random, NumPy, torch, CUDA) for reproducible inference.

    Args:
        seed: Seed value; ``-1`` means "randomize" — a fresh 32-bit seed is drawn.

    Returns:
        int: The seed actually applied (useful for logging/replay when ``-1`` was passed).
    """
    seed = int(seed)
    # -1 requests a random seed; keep it below 2**32 because np.random.seed rejects larger values.
    seed = seed if seed != -1 else random.randrange(1 << 32)
    logging.debug(f"Set seed to {seed}")
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    try:
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            # torch.backends.cudnn.deterministic = True
            # torch.backends.cudnn.benchmark = False
            # torch.backends.cudnn.enabled = True
    except Exception:
        # Best-effort: CUDA seeding can fail on broken or partially-installed GPU stacks;
        # don't abort CPU inference over it, but leave a trace instead of swallowing silently.
        logging.warning("Failed to seed CUDA RNGs", exc_info=True)
    return seed
141
+
123
142
def get_speakers (self ):
124
143
return self .speakers
125
144
@@ -165,20 +184,21 @@ def get_bert_feature(self, text, phones, word2ph, language):
165
184
def get_bert_and_cleaned_text_multilang(self, text: list):
    """Split mixed-language text into per-language segments, then clean each segment
    and accumulate its phones, word2ph mapping, normalized text and BERT features.

    Returns:
        tuple: (phones, word2ph, norm_text, bert) where norm_text is the joined
        normalized string and bert is the concatenated feature tensor on self.device.
    """
    segments = split_languages(text, expand_abbreviations=True, expand_hyphens=True)

    all_phones = []
    all_word2ph = []
    norm_pieces = []
    bert_chunks = []

    for segment, segment_lang in segments:
        seg_phones, seg_word2ph, seg_norm = self.get_cleaned_text(segment, segment_lang)
        seg_bert = self.get_bert_feature(segment, seg_phones, seg_word2ph, seg_norm)

        all_phones.extend(seg_phones)
        # word2ph may be None for languages without a word-to-phone mapping.
        if seg_word2ph is not None:
            all_word2ph.extend(seg_word2ph)
        norm_pieces.extend(seg_norm)
        bert_chunks.append(seg_bert)

    norm_text = ''.join(norm_pieces)
    bert = torch.cat(bert_chunks, dim=1).to(self.device, dtype=self.torch_dtype)

    return all_phones, all_word2ph, norm_text, bert
182
202
183
203
def get_spepc (self , audio , orig_sr ):
184
204
"""audio的sampling_rate与模型相同"""
@@ -238,6 +258,11 @@ def preprocess_text(self, text: str, lang: str, segment_size: int):
238
258
239
259
result = []
240
260
for text in texts :
261
+ text = text .strip ("\n " )
262
+ if (text [0 ] not in splits and len (self .get_first (text )) < 4 ):
263
+ text = "。" + text if lang != "en" else "." + text
264
+ if (text [- 1 ] not in splits ):
265
+ text += "。" if lang != "en" else "."
241
266
phones , word2ph , norm_text , bert_features = self .get_bert_and_cleaned_text_multilang (text )
242
267
res = {
243
268
"phones" : phones ,
@@ -251,7 +276,7 @@ def preprocess_prompt(self, reference_audio, reference_audio_sr, prompt_text: st
251
276
if self .prompt_cache .get ("prompt_text" ) != prompt_text :
252
277
if prompt_lang .lower () == "auto" :
253
278
prompt_lang = classify_language (prompt_text )
254
-
279
+ prompt_text = prompt_text . strip ( " \n " )
255
280
if (prompt_text [- 1 ] not in splits ):
256
281
prompt_text += "。" if prompt_lang != "en" else "."
257
282
phones , word2ph , norm_text = self .get_cleaned_text (prompt_text , prompt_lang )
@@ -438,9 +463,11 @@ def speed_change(self, input_audio: np.ndarray, speed_factor: float, sr: int):
438
463
439
464
def infer (self , text , lang , reference_audio , reference_audio_sr , prompt_text , prompt_lang , top_k , top_p ,
440
465
temperature , batch_size : int = 5 , batch_threshold : float = 0.75 , split_bucket : bool = True ,
441
- return_fragment : bool = False , speed_factor : float = 1.0 ,
466
+ return_fragment : bool = False , speed_factor : float = 1.0 , seed : int = - 1 ,
442
467
segment_size : int = config .gpt_sovits_config .segment_size , ** kwargs ):
443
468
469
+ self .set_seed (seed )
470
+
444
471
if return_fragment :
445
472
split_bucket = False
446
473
@@ -476,7 +503,7 @@ def infer(self, text, lang, reference_audio, reference_audio_sr, prompt_text, pr
476
503
if self .is_half :
477
504
all_bert_features = all_bert_features .half ()
478
505
479
- logging .debug (f"Infer text:{ [ '' . join ( text ) for text in norm_text ] } " )
506
+ logging .debug (f"Infer text:{ norm_text } " )
480
507
if no_prompt_text :
481
508
prompt = None
482
509
else :
0 commit comments