Skip to content

Commit 91c1de2

Browse files
alancuckinv-kkudrynski
authored and committed
[HiFi-GAN/PyT] Import amp_C (apex) only when necessary
1 parent ab3a0e4 commit 91c1de2

3 files changed

Lines changed: 36 additions & 31 deletions

File tree

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import amp_C
2+
import torch
3+
4+
5+
def apply_ema_decay(model, ema_model, decay):
    """Blend ema_model's weights toward model's: ema = decay*ema + (1-decay)*model.

    A falsy ``decay`` (0 or None) disables the update entirely.
    """
    if not decay:
        return
    src = model.state_dict()
    # A DistributedDataParallel-wrapped model stores params under a
    # 'module.' prefix; translate keys when only `model` is wrapped.
    needs_prefix = hasattr(model, 'module') and not hasattr(ema_model, 'module')
    for name, ema_param in ema_model.state_dict().items():
        key = name
        if needs_prefix and not key.startswith('module.'):
            key = 'module.' + key
        ema_param.copy_(decay * ema_param + (1 - decay) * src[key])
14+
15+
16+
def init_multi_tensor_ema(model, ema_model):
    """Prepare the inputs for fused multi-tensor EMA updates.

    Returns a tuple of (model weight tensors, EMA weight tensors,
    CUDA int32 overflow flag) suitable for `apply_multi_tensor_ema`.
    """
    model_weights = list(model.state_dict().values())
    ema_model_weights = list(ema_model.state_dict().values())
    # torch.cuda.IntTensor is a deprecated legacy constructor; build the
    # int32 overflow flag with the modern factory instead (same result).
    ema_overflow_buf = torch.zeros(1, dtype=torch.int32, device='cuda')
    return model_weights, ema_model_weights, ema_overflow_buf
21+
22+
23+
def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
    """Fused in-place EMA: ema_weights = decay*ema_weights + (1-decay)*model_weights.

    Uses apex's multi-tensor axpby kernel; `overflow_buf` records inf/nan.
    """
    chunk_size = 65536
    tensor_lists = [ema_weights, model_weights, ema_weights]
    amp_C.multi_tensor_axpby(
        chunk_size, overflow_buf, tensor_lists, decay, 1 - decay, -1)

PyTorch/SpeechSynthesis/HiFiGAN/common/utils.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949

5050
import soundfile # flac
5151

52-
import amp_C
5352
import matplotlib
5453

5554
matplotlib.use("Agg")
@@ -97,30 +96,6 @@ def adjust_fine_tuning_lr(args, ckpt_d):
9796
param_group['lr'] = new_v
9897

9998

100-
def apply_ema_decay(model, ema_model, decay):
    """Move ema_model's parameters toward model's by the EMA rule
    ema = decay * ema + (1 - decay) * model; a falsy decay is a no-op."""
    if not decay:
        return
    source_state = model.state_dict()
    # Only `model` may be DDP-wrapped; its keys then carry a 'module.' prefix.
    wrap_prefix = hasattr(model, 'module') and not hasattr(ema_model, 'module')
    for key, target in ema_model.state_dict().items():
        lookup = ('module.' + key) if (wrap_prefix and not key.startswith('module.')) else key
        target.copy_(decay * target + (1 - decay) * source_state[lookup])
109-
110-
111-
def init_multi_tensor_ema(model, ema_model):
    """Gather both models' weight tensors and a CUDA overflow flag for
    the fused multi-tensor EMA kernel."""
    weights = [w for w in model.state_dict().values()]
    ema_weights = [w for w in ema_model.state_dict().values()]
    overflow_flag = torch.cuda.IntTensor([0])
    return weights, ema_weights, overflow_flag
116-
117-
118-
def apply_multi_tensor_ema(decay, model_weights, ema_weights, overflow_buf):
    """In-place fused EMA via apex: ema = decay*ema + (1-decay)*model."""
    lists = [ema_weights, model_weights, ema_weights]
    # 65536 is the per-launch chunk size used by apex's multi-tensor kernels.
    amp_C.multi_tensor_axpby(65536, overflow_buf, lists, decay, 1 - decay, -1)
122-
123-
12499
def init_distributed(args, world_size, rank):
125100
assert torch.cuda.is_available(), "Distributed mode requires CUDA."
126101
print(f"{args.local_rank}: Initializing distributed training")

PyTorch/SpeechSynthesis/HiFiGAN/train.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ def main():
275275

276276
# setup EMA
277277
if args.ema_decay > 0:
278+
# buried import, requires apex
279+
from common.ema_utils import (apply_multi_tensor_ema,
280+
init_multi_tensor_ema)
281+
278282
gen_ema = models.get_model('HiFi-GAN', gen_config, 'cuda').cuda()
279283
mpd_ema = MultiPeriodDiscriminator(
280284
periods=args.mpd_periods,
@@ -316,9 +320,9 @@ def main():
316320
val_kwargs=dict(split=False),
317321
batch_size=1)
318322
if args.ema_decay > 0.0:
319-
gen_ema_params = utils.init_multi_tensor_ema(gen, gen_ema)
320-
mpd_ema_params = utils.init_multi_tensor_ema(mpd, mpd_ema)
321-
msd_ema_params = utils.init_multi_tensor_ema(msd, msd_ema)
323+
gen_ema_params = init_multi_tensor_ema(gen, gen_ema)
324+
mpd_ema_params = init_multi_tensor_ema(mpd, mpd_ema)
325+
msd_ema_params = init_multi_tensor_ema(msd, msd_ema)
322326

323327
epochs_done = 0
324328

@@ -428,9 +432,9 @@ def main():
428432
metrics.accumulate()
429433

430434
if args.ema_decay > 0.0:
431-
utils.apply_multi_tensor_ema(args.ema_decay, *gen_ema_params)
432-
utils.apply_multi_tensor_ema(args.ema_decay, *mpd_ema_params)
433-
utils.apply_multi_tensor_ema(args.ema_decay, *msd_ema_params)
435+
apply_multi_tensor_ema(args.ema_decay, *gen_ema_params)
436+
apply_multi_tensor_ema(args.ema_decay, *mpd_ema_params)
437+
apply_multi_tensor_ema(args.ema_decay, *msd_ema_params)
434438

435439
metrics.finish_iter() # done accumulating
436440
if iters_all % args.step_logs_interval == 0:

0 commit comments

Comments
 (0)