1616from tn .processor import Processor
1717from tn .utils import get_abs_path
1818
19- from pynini import string_file , accep , cross
19+ from pynini import string_file , accep , cross , union
2020from pynini .lib .pynutil import delete , insert , add_weight
2121
2222
@@ -36,6 +36,11 @@ def build_tagger(self):
3636 get_abs_path ('../itn/chinese/data/measure/units_zh.tsv' ))
3737 sign = string_file (
3838 get_abs_path ('../itn/chinese/data/number/sign.tsv' )) # + -
39+ digit = string_file (
40+ get_abs_path ('../itn/chinese/data/number/digit.tsv' )) # 1 ~ 9
41+ digit_zh = string_file (
42+ get_abs_path ('../itn/chinese/data/number/digit_zh.tsv' )) # 1 ~ 9
43+ addzero = insert ('0' )
3944 to = cross ('到' , '~' ) | cross ('到百分之' , '~' )
4045
4146 units = add_weight (
@@ -55,8 +60,35 @@ def build_tagger(self):
5560
5661 # 十千米每小时 => 10km/h, 十一到一百千米每小时 => 11~100km/h
5762 measure = number + (to + number ).ques + units
58- tagger = insert ('value: "' ) + (measure | percent ) + insert ('"' )
5963
64+ # XXX: 特殊case处理, ignore enable_standalone_number
65+ # digit + union("百", "千", "万") + digit + unit
66+ unit_sp_case1 = [
67+ '年' ,
68+ '月' ,
69+ '个月' ,
70+ '周' ,
71+ '天' ,
72+ '位' ,
73+ '次' ,
74+ '个' ,
75+ '顿' ,
76+ ]
77+ if self .enable_0_to_9 :
78+ measure_sp = add_weight (
79+ ((digit + delete ('百' ) + add_weight (addzero ** 2 , 1.0 )) |
80+ (digit + delete ('千' ) + add_weight (addzero ** 3 , 1.0 )) |
81+ (digit + delete ('万' ) + add_weight (addzero ** 4 , 1.0 ))) +
82+ insert (' ' ) + digit + union (* unit_sp_case1 ), - 0.5 )
83+ else :
84+ measure_sp = add_weight (
85+ ((digit + delete ('百' ) + add_weight (addzero ** 2 , 1.0 )) |
86+ (digit + delete ('千' ) + add_weight (addzero ** 3 , 1.0 )) |
87+ (digit + delete ('万' ) + add_weight (addzero ** 4 , 1.0 ))) +
88+ digit_zh + union (* unit_sp_case1 ), - 0.5 )
89+
90+ tagger = insert ('value: "' ) + (measure | measure_sp
91+ | percent ) + insert ('"' )
6092 # 每小时十千米 => 10km/h, 每小时三十到三百一十一千米 => 30~311km/h
6193 tagger |= (insert ('denominator: "' ) + delete ('每' ) + units +
6294 insert ('" numerator: "' ) + measure + insert ('"' ))
0 commit comments