Skip to content

Commit 87aa4b0

Browse files
committed
Merge: [QuartzNet/PyT] Support NeMo checkpoints
2 parents e5efe02 + b472e61 commit 87aa4b0

6 files changed

Lines changed: 286 additions & 15 deletions

File tree

PyTorch/SpeechRecognition/QuartzNet/README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ This repository provides a script and recipe to train the QuartzNet model to ach
1313
* [Enabling mixed precision](#enabling-mixed-precision)
1414
* [Enabling TF32](#enabling-tf32)
1515
* [Glossary](#glossary)
16+
* [Language support and NeMo compatibility](#language-support-and-nemo-compatibility)
1617
- [Setup](#setup)
1718
* [Requirements](#requirements)
1819
- [Quick Start Guide](#quick-start-guide)
@@ -144,6 +145,23 @@ Assigns a probability distribution over a sequence of words. Given a sequence of
144145
**Pre-training**
145146
Training a model on vast amounts of data on the same (or different) task to build general understandings.
146147

148+
### Language support and NeMo compatibility
149+
150+
This repository allows training and running models in languages other than English.
151+
152+
During inference, QuartzNet models trained with [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) can also be used — for instance, one of the pre-trained models
153+
for Catalan, French, German, Italian, Mandarin Chinese, Polish, Russian, or Spanish available on [NGC](https://ngc.nvidia.com/).
154+
To download automatically, run:
155+
```bash
156+
bash scripts/download_quartznet.sh [ca|fr|de|it|zh|pl|ru|es]
157+
```
158+
159+
Pre-trained models can be explicitly converted from the `.nemo` checkpoint format to `.pt` and vice versa.
160+
For more details, run:
161+
```bash
162+
python nemo_dle_model_converter.py --help
163+
```
164+
147165
## Setup
148166

149167
The following section lists the requirements that you need to meet in order to start training the QuartzNet model.

PyTorch/SpeechRecognition/QuartzNet/common/audio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(self, filename, target_sr=None, int_values=False, offset=0,
7575
self._samples = samples
7676
self._sample_rate = sample_rate
7777
if self._samples.ndim >= 2:
78-
self._samples = np.mean(self._samples, 1)
78+
self._samples = np.mean(self._samples, 0)
7979

8080
def __eq__(self, other):
8181
"""Return whether two objects are equal."""

PyTorch/SpeechRecognition/QuartzNet/inference.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@
3535
SingleAudioDataset)
3636
from common.features import BaseFeatures, FilterbankFeatures
3737
from common.helpers import print_once, process_evaluation_epoch
38-
from quartznet.model import GreedyCTCDecoder, QuartzNet
3938
from common.tb_dllogger import stdout_metric_format, unique_log_fpath
39+
from nemo_dle_model_converter import load_nemo_ckpt
40+
from quartznet.model import GreedyCTCDecoder, QuartzNet
4041

4142

4243
def get_parser():
@@ -189,7 +190,25 @@ def main():
189190
distrib.init_process_group(backend='nccl', init_method='env://')
190191
print_once(f'Inference with {distrib.get_world_size()} GPUs')
191192

192-
cfg = config.load(args.model_config)
193+
if args.ckpt is not None:
194+
print(f'Loading the model from {args.ckpt} ...')
195+
print(f'{args.model_config} will be overriden.')
196+
if args.ckpt.lower().endswith('.nemo'):
197+
ckpt, cfg = load_nemo_ckpt(args.ckpt)
198+
else:
199+
cfg = config.load(args.model_config)
200+
ckpt = torch.load(args.ckpt, map_location='cpu')
201+
202+
sd_key = 'ema_state_dict' if args.ema else 'state_dict'
203+
if args.ema and 'ema_state_dict' not in ckpt:
204+
print(f'WARNING: EMA weights are unavailable in {args.ckpt}.')
205+
sd_key = 'state_dict'
206+
state_dict = ckpt[sd_key]
207+
208+
else:
209+
cfg = config.load(args.model_config)
210+
state_dict = None
211+
193212
config.apply_config_overrides(cfg, args)
194213

195214
symbols = helpers.add_ctc_blank(cfg['labels'])
@@ -267,11 +286,7 @@ def main():
267286
model = QuartzNet(encoder_kw=config.encoder(cfg),
268287
decoder_kw=config.decoder(cfg, n_classes=len(symbols)))
269288

270-
if args.ckpt is not None:
271-
print(f'Loading the model from {args.ckpt} ...')
272-
checkpoint = torch.load(args.ckpt, map_location="cpu")
273-
key = 'ema_state_dict' if args.ema else 'state_dict'
274-
state_dict = checkpoint[key]
289+
if state_dict is not None:
275290
model.load_state_dict(state_dict, strict=True)
276291

277292
model.to(device)
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
import argparse
2+
import io
3+
import sys
4+
from copy import deepcopy
5+
from functools import reduce
6+
from pathlib import Path
7+
from subprocess import CalledProcessError, check_output
8+
9+
import torch
10+
import yaml
11+
12+
import quartznet.config
13+
from common import helpers
14+
from common.features import FilterbankFeatures
15+
from quartznet.config import load as load_yaml
16+
from quartznet.model import QuartzNet, MaskedConv1d
17+
18+
19+
# Corresponding DLE <-> NeMo config keys
20+
cfg_key_map = {
21+
("input_val", "audio_dataset", "sample_rate"): ("preprocessor", "sample_rate"),
22+
("input_val", "filterbank_features", "dither"): ("preprocessor", "dither"),
23+
("input_val", "filterbank_features", "frame_splicing"): ("preprocessor", "frame_splicing"),
24+
("input_val", "filterbank_features", "n_fft"): ("preprocessor", "n_fft"),
25+
("input_val", "filterbank_features", "n_filt"): ("preprocessor", "features"),
26+
("input_val", "filterbank_features", "normalize"): ("preprocessor", "normalize"),
27+
("input_val", "filterbank_features", "sample_rate"): ("preprocessor", "sample_rate"),
28+
("input_val", "filterbank_features", "window"): ("preprocessor", "window"),
29+
("input_val", "filterbank_features", "window_size"): ("preprocessor", "window_size"),
30+
("input_val", "filterbank_features", "window_stride"): ("preprocessor", "window_stride"),
31+
("labels",): ("decoder", "vocabulary"),
32+
("quartznet", "decoder", "in_feats"): ("decoder", "feat_in"),
33+
("quartznet", "encoder", "activation"): ("encoder", "activation"),
34+
("quartznet", "encoder", "blocks"): ("encoder", "jasper"),
35+
("quartznet", "encoder", "frame_splicing"): ("preprocessor", "frame_splicing"),
36+
("quartznet", "encoder", "in_feats"): ("encoder", "feat_in"),
37+
("quartznet", "encoder", "use_conv_masks"): ("encoder", "conv_mask"),
38+
}
39+
40+
41+
def load_nemo_ckpt(fpath):
42+
"""Make a DeepLearningExamples state_dict and config from a .nemo file."""
43+
try:
44+
cmd = ['tar', 'Oxzf', fpath, './model_config.yaml']
45+
nemo_cfg = yaml.safe_load(io.BytesIO(check_output(cmd)))
46+
47+
cmd = ['tar', 'Oxzf', fpath, './model_weights.ckpt']
48+
ckpt = torch.load(io.BytesIO(check_output(cmd)), map_location="cpu")
49+
50+
except (FileNotFoundError, CalledProcessError):
51+
print('WARNING: Could not uncompress with tar. '
52+
'Falling back to the tarfile module (might take a few minutes).')
53+
import tarfile
54+
with tarfile.open(fpath, "r:gz") as tar:
55+
f = tar.extractfile(tar.getmember("./model_config.yaml"))
56+
nemo_cfg = yaml.safe_load(f)
57+
58+
f = tar.extractfile(tar.getmember("./model_weights.ckpt"))
59+
ckpt = torch.load(f, map_location="cpu")
60+
61+
remap = lambda k: (k.replace("encoder.encoder", "encoder.layers")
62+
.replace("decoder.decoder_layers", "decoder.layers")
63+
.replace("conv.weight", "weight"))
64+
dle_ckpt = {'state_dict': {remap(k): v for k, v in ckpt.items()
65+
if "preproc" not in k}}
66+
dle_cfg = config_from_nemo(nemo_cfg)
67+
return dle_ckpt, dle_cfg
68+
69+
70+
def save_nemo_ckpt(dle_ckpt, dle_cfg, dest_path):
71+
"""Save a DeepLearningExamples model as a .nemo file."""
72+
cfg = deepcopy(dle_cfg)
73+
74+
dle_ckpt = torch.load(dle_ckpt, map_location="cpu")["ema_state_dict"]
75+
76+
# Build a DLE model instance and fill with weights
77+
symbols = helpers.add_ctc_blank(cfg['labels'])
78+
enc_kw = quartznet.config.encoder(cfg)
79+
dec_kw = quartznet.config.decoder(cfg, n_classes=len(symbols))
80+
model = QuartzNet(enc_kw, dec_kw)
81+
model.load_state_dict(dle_ckpt, strict=True)
82+
83+
# Rename core modules, e.g., encoder.layers -> encoder.encoder
84+
model.encoder._modules['encoder'] = model.encoder._modules.pop('layers')
85+
model.decoder._modules['decoder_layers'] = model.decoder._modules.pop('layers')
86+
87+
# MaskedConv1d is made via composition in NeMo, and via inheritance in DLE
88+
# Params for MaskedConv1d in NeMo have an additional '.conv.' infix
89+
def rename_convs(module):
90+
for name in list(module._modules.keys()):
91+
submod = module._modules[name]
92+
93+
if isinstance(submod, MaskedConv1d):
94+
module._modules[f'{name}.conv'] = module._modules.pop(name)
95+
else:
96+
rename_convs(submod)
97+
98+
rename_convs(model.encoder.encoder)
99+
100+
# Use FilterbankFeatures to calculate fbanks and store with model weights
101+
feature_processor = FilterbankFeatures(
102+
**dle_cfg['input_val']['filterbank_features'])
103+
104+
nemo_ckpt = model.state_dict()
105+
nemo_ckpt["preprocessor.featurizer.fb"] = feature_processor.fb
106+
nemo_ckpt["preprocessor.featurizer.window"] = feature_processor.window
107+
108+
nemo_cfg = config_to_nemo(dle_cfg)
109+
110+
# Prepare the directory for zipping
111+
ckpt_files = dest_path / "ckpt_files"
112+
ckpt_files.mkdir(exist_ok=True, parents=False)
113+
with open(ckpt_files / "model_config.yaml", "w") as f:
114+
yaml.dump(nemo_cfg, f)
115+
torch.save(nemo_ckpt, ckpt_files / "model_weights.ckpt")
116+
117+
with tarfile.open(dest_path / "quartznet.nemo", "w:gz") as tar:
118+
tar.add(ckpt_files, arcname="./")
119+
120+
121+
def save_dle_ckpt(ckpt, cfg, dest_dir):
122+
torch.save(ckpt, dest_dir / "model.pt")
123+
with open(dest_dir / "model_config.yaml", "w") as f:
124+
yaml.dump(cfg, f)
125+
126+
127+
def set_nested_item(tgt, src, tgt_keys, src_keys):
128+
"""Assigns nested dict keys, e.g., d1[a][b][c] = d2[e][f][g][h]."""
129+
tgt_nested = reduce(lambda d, k: d[k], tgt_keys[:-1], tgt)
130+
tgt_nested[tgt_keys[-1]] = reduce(lambda d, k: d[k], src_keys, src)
131+
132+
133+
def config_from_nemo(nemo_cfg):
134+
"""Convert a NeMo config to the DeepLearningExamples format."""
135+
dle_cfg = {
136+
'name': 'QuartzNet',
137+
'input_val': {
138+
'audio_dataset': {
139+
'normalize_transcripts': True,
140+
},
141+
'filterbank_features': {
142+
'pad_align': 16,
143+
},
144+
},
145+
'quartznet': {
146+
'decoder': {},
147+
'encoder': {},
148+
},
149+
}
150+
151+
for dle_keys, nemo_keys in cfg_key_map.items():
152+
try:
153+
set_nested_item(dle_cfg, nemo_cfg, dle_keys, nemo_keys)
154+
except KeyError:
155+
print(f'WARNING: Could not load config {nemo_keys} as {dle_keys}.')
156+
157+
# mapping kernel_size is not expressable with cfg_map
158+
for block in dle_cfg["quartznet"]["encoder"]["blocks"]:
159+
block["kernel_size"] = block.pop("kernel")
160+
161+
return dle_cfg
162+
163+
164+
def config_to_nemo(dle_cfg):
165+
"""Convert a DeepLearningExamples config to a NeMo format."""
166+
nemo_cfg = {
167+
"target": "nemo.collections.asr.models.ctc_models.EncDecCTCModel",
168+
"dropout": 0.0,
169+
"preprocessor": {
170+
"_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor",
171+
"stft_conv": False,
172+
},
173+
"encoder": {
174+
"_target_": "nemo.collections.asr.modules.ConvASREncoder",
175+
"jasper": {}
176+
},
177+
"decoder": {
178+
"_target_": "nemo.collections.asr.modules.ConvASRDecoder",
179+
},
180+
}
181+
182+
for dle_keys, nemo_keys in cfg_key_map.items():
183+
try:
184+
set_nested_item(nemo_cfg, dle_cfg, nemo_keys, dle_keys)
185+
except KeyError:
186+
print(f"WARNING: Could not load config {dle_keys} as {nemo_keys}.")
187+
188+
nemo_cfg["sample_rate"] = nemo_cfg["preprocessor"]["sample_rate"]
189+
nemo_cfg["repeat"] = nemo_cfg["encoder"]["jasper"][1]["repeat"]
190+
nemo_cfg["separable"] = nemo_cfg["encoder"]["jasper"][1]["separable"]
191+
nemo_cfg["labels"] = nemo_cfg["decoder"]["vocabulary"]
192+
nemo_cfg["decoder"]["num_classes"] = len(nemo_cfg["decoder"]["vocabulary"])
193+
194+
# mapping kernel_size is not expressable with cfg_map
195+
for block in nemo_cfg["encoder"]["jasper"]:
196+
if "kernel_size" in block:
197+
block["kernel"] = block.pop("kernel_size")
198+
199+
return nemo_cfg
200+
201+
202+
if __name__ == "__main__":
203+
parser = argparse.ArgumentParser(description="QuartzNet DLE <-> NeMo model converter.")
204+
parser.add_argument("source_model", type=Path,
205+
help="A DLE or NeMo QuartzNet model to be converted (.pt or .nemo, respectively)")
206+
parser.add_argument("dest_dir", type=Path, help="Destination directory")
207+
parser.add_argument("--dle_config_yaml", type=Path,
208+
help="A DLE config .yaml file, required only to convert DLE -> NeMo")
209+
args = parser.parse_args()
210+
211+
ext = args.source_model.suffix.lower()
212+
if ext == ".nemo":
213+
ckpt, cfg = load_nemo_ckpt(args.source_model)
214+
save_dle_ckpt(ckpt, cfg, args.dest_dir)
215+
216+
elif ext == ".pt":
217+
dle_cfg = load_yaml(args.dle_config_yaml)
218+
save_nemo_ckpt(args.source_model, dle_cfg, args.dest_dir)
219+
220+
else:
221+
raise ValueError(f"Unknown extension {ext}.")
222+
223+
print('Converted succesfully.')

PyTorch/SpeechRecognition/QuartzNet/scripts/download_quartznet.sh

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,37 @@
22

33
set -e
44

5-
: ${MODEL_DIR:="pretrained_models/quartznet"}
6-
MODEL_ZIP="quartznet_pyt_ckpt_amp_21.03.0.zip"
7-
MODEL="nvidia_quartznet_210504.pt"
8-
MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/quartznet_pyt_ckpt_amp/versions/21.03.0/zip"
5+
: ${LANGUAGE:=${1:-en}}
6+
: ${MODEL_DIR:="pretrained_models/quartznet_${LANGUAGE}"}
7+
8+
case $LANGUAGE in
9+
en)
10+
MODEL="nvidia_quartznet_210504.pt"
11+
MODEL_ZIP="quartznet_pyt_ckpt_amp_21.03.0.zip"
12+
MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/quartznet_pyt_ckpt_amp/versions/21.03.0/zip"
13+
;;
14+
ca|de|es|fr|it|pl|ru|zh)
15+
MODEL="stt_${LANGUAGE}_quartznet15x5.nemo"
16+
MODEL_URL="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_${LANGUAGE}_quartznet15x5/versions/1.0.0rc1/zip"
17+
MODEL_ZIP="stt_${LANGUAGE}_quartznet15x5_1.0.0rc1.zip"
18+
;;
19+
*)
20+
echo "Unsupported language $LANGUAGE"
21+
exit 1
22+
;;
23+
esac
924

1025
mkdir -p "$MODEL_DIR"
1126

1227
if [ ! -f "${MODEL_DIR}/${MODEL_ZIP}" ]; then
1328
echo "Downloading ${MODEL_ZIP} ..."
14-
wget -qO ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
29+
wget -O ${MODEL_DIR}/${MODEL_ZIP} ${MODEL_URL} \
1530
|| { echo "ERROR: Failed to download ${MODEL_ZIP} from NGC"; exit 1; }
1631
fi
1732

1833
if [ ! -f "${MODEL_DIR}/${MODEL}" ]; then
1934
echo "Extracting ${MODEL} ..."
20-
unzip -qo ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR} \
35+
unzip -o ${MODEL_DIR}/${MODEL_ZIP} -d ${MODEL_DIR} \
2136
|| { echo "ERROR: Failed to extract ${MODEL_ZIP}"; exit 1; }
2237

2338
echo "OK"

PyTorch/SpeechRecognition/QuartzNet/scripts/inference.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
: ${DATA_DIR:=${1:-"/datasets/LibriSpeech"}}
1818
: ${MODEL_CONFIG:=${2:-"configs/quartznet15x5_speedp-online-1.15_speca.yaml"}}
1919
: ${OUTPUT_DIR:=${3:-"/results"}}
20-
: ${CHECKPOINT:=${4:-"pretrained_models/quartznet/nvidia_quartznet_210504.pt"}}
20+
: ${CHECKPOINT:=${4:-"pretrained_models/quartznet_en/nvidia_quartznet_210504.pt"}}
2121
: ${DATASET:="test-other"}
2222
: ${LOG_FILE:=""}
2323
: ${CUDNN_BENCHMARK:=false}

0 commit comments

Comments
 (0)