Skip to content

Commit c2bb3fe

Browse files
MuyangDun and v-kkudrynski
authored and committed
[FastPitch/PyT] Add mixed English and Mandarin bilingual support
1 parent 84be38e commit c2bb3fe

22 files changed

Lines changed: 1045 additions & 73 deletions

PyTorch/SpeechSynthesis/FastPitch/README.md

Lines changed: 174 additions & 50 deletions
Large diffs are not rendered by default.
Binary file not shown.

PyTorch/SpeechSynthesis/FastPitch/common/text/symbols.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,26 @@ def get_symbols(symbol_set='english_basic'):
3131
_accented = 'áçéêëñöøćž'
3232
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
3333
symbols = list(_punctuation + _math + _special + _accented + _letters) + _arpabet
34+
elif symbol_set == 'english_mandarin_basic':
35+
from .zh.chinese import chinese_punctuations, valid_symbols as mandarin_valid_symbols
36+
37+
# Prepend "#" to mandarin phonemes to ensure uniqueness (some are the same as uppercase letters):
38+
_mandarin_phonemes = ['#' + s for s in mandarin_valid_symbols]
39+
40+
_pad = '_'
41+
_punctuation = '!\'(),.:;? '
42+
_chinese_punctuation = ["#" + p for p in chinese_punctuations]
43+
_special = '-'
44+
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
45+
symbols = list(_pad + _special + _punctuation + _letters) + _arpabet + _mandarin_phonemes + _chinese_punctuation
3446
else:
3547
raise Exception("{} symbol set does not exist".format(symbol_set))
3648

3749
return symbols
3850

3951

4052
def get_pad_idx(symbol_set='english_basic'):
    """Return the index of the padding symbol for ``symbol_set``.

    Raises:
        Exception: if the symbol set has no pad index defined yet.
    """
    supported = {'english_basic', 'english_basic_lowercase',
                 'english_mandarin_basic'}
    if symbol_set not in supported:
        raise Exception("{} symbol set not used yet".format(symbol_set))
    # Every supported symbol set places the pad token at position 0.
    return 0

PyTorch/SpeechSynthesis/FastPitch/common/text/text_processing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,13 @@ def encode_text(self, text, return_all=False):
162162
return text_encoded, text_clean, text_arpabet
163163

164164
return text_encoded
165+
166+
167+
def get_text_processing(symbol_set, text_cleaners, p_arpabet):
    """Build the text processor appropriate for ``symbol_set``.

    Args:
        symbol_set: name of the symbol inventory, e.g. 'english_basic'.
        text_cleaners: cleaner names forwarded to the processor.
        p_arpabet: probability of converting an English word to ARPAbet.

    Returns:
        A ``TextProcessing`` instance, or a ``MandarinTextProcessing`` for
        the bilingual 'english_mandarin_basic' symbol set.

    Raises:
        ValueError: if ``symbol_set`` is not recognized.
    """
    if symbol_set in ['english_basic', 'english_basic_lowercase', 'english_expanded']:
        return TextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet)
    elif symbol_set == 'english_mandarin_basic':
        # Imported lazily so English-only runs don't require the Mandarin
        # front-end dependencies (pypinyin, pinyin lexicon file).
        from common.text.zh.mandarin_text_processing import MandarinTextProcessing
        return MandarinTextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet)
    else:
        # Original message was a grammatical fragment
        # ("No TextProcessing for symbol set X unknown."); fixed.
        raise ValueError(f"No TextProcessing for unknown symbol set '{symbol_set}'.")
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of the NVIDIA CORPORATION nor the
12+
# names of its contributors may be used to endorse or promote products
13+
# derived from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16+
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17+
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19+
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20+
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21+
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22+
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24+
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
#
26+
# *****************************************************************************
27+
28+
import re
29+
30+
from pypinyin import lazy_pinyin, Style
31+
32+
33+
# Mandarin pinyin phoneme inventory (initials and finals); '^' marks the
# null initial. Order is significant: symbol ids are derived from it.
_PINYIN_PHONEMES = ['^', 'A', 'AI', 'AN', 'ANG', 'AO', 'B', 'C', 'CH', 'D',
                    'E', 'EI', 'EN', 'ENG', 'ER', 'F', 'G', 'H', 'I', 'IE',
                    'IN', 'ING', 'IU', 'J', 'K', 'L', 'M', 'N', 'O', 'ONG',
                    'OU', 'P', 'Q', 'R', 'S', 'SH', 'T', 'U', 'UI', 'UN',
                    'V', 'VE', 'VN', 'W', 'X', 'Y', 'Z', 'ZH']

# Tone markers: 1-4 are the lexical tones, 5 is the neutral tone.
tones = [str(t) for t in range(1, 6)]

# Fullwidth punctuation marks that may appear in Chinese text.
chinese_punctuations = ",。?!;:、‘’“”()【】「」《》"

# Public symbol list: phonemes followed by the tone digits.
valid_symbols = _PINYIN_PHONEMES + tones
41+
42+
43+
def load_pinyin_dict(path="common/text/zh/pinyin_dict.txt"):
    """Load the pinyin-syllable -> phoneme-list lexicon.

    Each line of the file is ``SYLLABLE PHONEME [PHONEME ...]``,
    whitespace-separated. Blank lines are ignored (the original
    comprehension raised IndexError on them).

    Args:
        path: lexicon file path, relative to the working directory.

    Returns:
        dict mapping an uppercase pinyin syllable to its phoneme list.
    """
    lexicon = {}
    # Explicit encoding: the file ships with the repo as UTF-8 and must not
    # depend on the platform's default locale encoding.
    with open(path, encoding='utf-8') as f:
        for line in f:
            fields = line.split()  # split once per line, not twice
            if fields:
                lexicon[fields[0]] = fields[1:]
    return lexicon
46+
47+
pinyin_dict = load_pinyin_dict()
48+
49+
50+
def is_chinese(text):
    """Return True when ``text`` starts with a CJK ideograph or a Chinese
    punctuation mark (only the first character is inspected)."""
    first = text[0]
    # Keep the range check first: it short-circuits for ordinary hanzi.
    return '\u4e00' <= first <= '\u9fff' or first in chinese_punctuations
52+
53+
54+
def split_text(text):
    """Split ``text`` into alternating non-Chinese and Chinese chunks.

    The character class is wrapped in a capturing group, so ``re.split``
    keeps the Chinese runs in the returned list instead of dropping them.
    """
    pattern = '([\u4e00-\u9fff' + chinese_punctuations + ']+)'
    return re.split(pattern, text)
57+
58+
59+
def chinese_text_to_symbols(text):
    """Convert a Chinese text segment into phoneme/tone symbols.

    Returns:
        (symbols, phonemes_and_tones): ``symbols`` is a flat list of pinyin
        phonemes, tone digits ('1'-'5') and Chinese punctuation marks;
        ``phonemes_and_tones`` is the same content rendered as a string with
        each syllable wrapped in '{...}' (ARPAbet-style display form).
    """
    symbols = []
    phonemes_and_tones = ""

    # convert text to mandarin pinyin sequence
    # ignore polyphonic words as it has little effect on training
    # (Style.TONE3 appends the tone digit to each syllable, e.g. 'zhong1')
    pinyin_seq = lazy_pinyin(text, style=Style.TONE3)

    for item in pinyin_seq:
        if item in chinese_punctuations:
            # Punctuation passes through as a single symbol.
            symbols += [item]
            phonemes_and_tones += ' ' + item
            continue
        if not item[-1].isdigit():
            # No tone digit means the neutral tone; encode it as '5'.
            item += '5'
        item, tone = item[:-1], item[-1]
        # NOTE(review): raises KeyError if the syllable is missing from the
        # lexicon file — presumably the shipped dictionary is exhaustive.
        phonemes = pinyin_dict[item.upper()]
        symbols += phonemes
        symbols += [tone]

        phonemes_and_tones += '{' + ' '.join(phonemes + [tone]) + '}'

    return symbols, phonemes_and_tones
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import re
2+
import numpy as np
3+
from .chinese import split_text, is_chinese, chinese_text_to_symbols
4+
from ..text_processing import TextProcessing
5+
6+
7+
class MandarinTextProcessing(TextProcessing):
    """TextProcessing variant for mixed English/Mandarin text.

    English segments are delegated to the base class; Chinese segments are
    converted to pinyin phonemes, which are stored in the symbol table with
    a '#' prefix to keep them distinct from uppercase English letters.
    """

    def __init__(self, symbol_set, cleaner_names, p_arpabet=0.0,
                 handle_arpabet='word', handle_arpabet_ambiguous='ignore',
                 expand_currency=True):
        super().__init__(symbol_set, cleaner_names, p_arpabet, handle_arpabet,
                         handle_arpabet_ambiguous, expand_currency)

    def sequence_to_text(self, sequence):
        """Decode a sequence of symbol ids back into readable text.

        ARPAbet symbols and Mandarin syllables are re-wrapped in curly
        braces; '#'-prefixed Chinese punctuation maps back to the bare mark.
        Unknown ids are silently skipped.
        """
        result = ''

        tmp = ''  # accumulates phonemes of the in-progress Mandarin syllable
        for symbol_id in sequence:
            if symbol_id in self.id_to_symbol:
                s = self.id_to_symbol[symbol_id]
                # Enclose ARPAbet and mandarin phonemes back in curly braces:
                if len(s) > 1 and s[0] == '@':
                    s = '{%s}' % s[1:]
                    result += s
                elif len(s) > 1 and s[0] == '#' and s[1].isdigit():  # mandarin tone
                    # A tone digit terminates the syllable: flush the buffer.
                    tmp += s[1] + '} '
                    result += tmp
                    tmp = ''
                elif len(s) > 1 and s[0] == '#' and (s[1].isalpha() or s[1] == '^'):  # mandarin phoneme
                    if tmp == '':
                        tmp += ' {' + s[1:] + ' '
                    else:
                        tmp += s[1:] + ' '
                elif len(s) > 1 and s[0] == '#':  # chinese punctuation
                    s = s[1]
                    result += s
                else:
                    result += s

        # NOTE(review): the second replace is a no-op as written (space ->
        # space); it was presumably meant to collapse double spaces — confirm
        # against the upstream source before changing it.
        return result.replace('}{', ' ').replace(' ', ' ')

    def chinese_symbols_to_sequence(self, symbols):
        # Mandarin symbols live in the table with a '#' prefix (see symbols.py).
        return self.symbols_to_sequence(['#' + s for s in symbols])

    def encode_text(self, text, return_all=False):
        """Encode mixed English/Chinese ``text`` into a symbol-id sequence.

        Args:
            text: input string possibly mixing English and Chinese.
            return_all: when True, also return the cleaned text and its
                phonetic (ARPAbet/pinyin) rendering.

        Returns:
            The encoded id list, or ``(ids, clean_text, phonetic_text)``
            when ``return_all`` is True.
        """
        # split the text into English and Chinese segments
        segments = [segment for segment in split_text(text) if segment != ""]

        text_encoded = []
        text_clean = ""
        text_arpabet = ""

        for segment in segments:
            if is_chinese(segment[0]):  # process the Chinese segment
                chinese_symbols, segment_arpabet = chinese_text_to_symbols(segment)
                segment_encoded = self.chinese_symbols_to_sequence(chinese_symbols)
                # Chinese text needs no cleaning; keep the segment verbatim.
                # (Removed the original's redundant self-assignment of
                # segment_encoded here.)
                segment_clean = segment
            else:  # process the English segment
                segment_encoded, segment_clean, segment_arpabet = \
                    super().encode_text(segment, return_all=True)

            text_encoded += segment_encoded
            text_clean += segment_clean
            text_arpabet += segment_arpabet

        if return_all:
            return text_encoded, text_clean, text_arpabet

        return text_encoded

0 commit comments

Comments
 (0)