This repository contains the official PyTorch implementations for our paper:
- Yuchen Hu, Chen Chen, Ruizhe Li, Qiushi Zhu, Eng Siong Chng. "Noise-aware Speech Enhancement using Diffusion Probabilistic Model".
Our code is based on prior work SGMSE+.
- Create a new virtual environment with Python 3.8 (we have not tested other Python versions, but they may work).
- Install the package dependencies via
- Install the package dependencies via `pip install -r requirements.txt`.
- If using W&B logging (default):
    - Set up a wandb.ai account.
    - Log in via `wandb login` before running our code.
- If not using W&B logging:
    - Pass the option `--no_wandb` to `train.py`.
    - Your logs will be stored as local TensorBoard logs. Run `tensorboard --logdir logs/` to see them.
- We release a pretrained checkpoint for the model trained on VoiceBank-DEMAND, as in the paper.
- We also provide testing samples before and after NASE processing for comparison.
Usage:
- For resuming training, you can use the `--resume_from_checkpoint` option of `train.py`.
- For evaluating these checkpoints, use the `--ckpt` option of `enhancement.py` (see the Evaluation section below).
Training is done by executing `train.py`. A minimal running example with default settings can be run with:

```
python train.py --base_dir <your_base_dir> --inject_type <inject_type> --pretrain_class_model <pretrained_beats>
```

where `your_base_dir` should be a path to a folder containing subdirectories `train/` and `valid/` (optionally `test/` as well). Each subdirectory must itself have two subdirectories `clean/` and `noisy/`, with the same filenames present in both. We currently only support training with `.wav` files.
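The expected layout can be sanity-checked with a small helper before launching training (a sketch based on the description above, not part of the repository):

```python
import os

def check_base_dir(base_dir):
    """Check the layout described above: each split directory must contain
    clean/ and noisy/ subdirectories with matching .wav filenames."""
    problems = []
    for split in ("train", "valid"):
        clean = os.path.join(base_dir, split, "clean")
        noisy = os.path.join(base_dir, split, "noisy")
        if not (os.path.isdir(clean) and os.path.isdir(noisy)):
            problems.append(f"missing clean/ or noisy/ in {split}/")
            continue

        def wavs(d):
            return {f for f in os.listdir(d) if f.endswith(".wav")}

        if wavs(clean) != wavs(noisy):
            problems.append(f"filename mismatch in {split}/")
    return problems
```

An empty return value means the layout matches what `train.py` expects.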
`inject_type` should be chosen from `["addition", "concat", "cross-attention"]`.
`pretrained_beats` should be the path to the pre-trained BEATs checkpoint.
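The three injection types can be illustrated with a minimal sketch (tensor shapes and module choices here are hypothetical, for illustration only; the actual NASE implementation may differ):

```python
import torch
import torch.nn as nn

# Hypothetical dimensions for illustration; the real model's sizes differ.
FEAT_DIM = 64   # dimensionality of the score-model features
COND_DIM = 64   # dimensionality of the noise embedding (e.g. from BEATs)

def inject(features, noise_emb, inject_type):
    """Fuse a noise embedding into the model features.

    features:  (batch, time, FEAT_DIM)
    noise_emb: (batch, COND_DIM), broadcast over the time axis
    """
    if inject_type == "addition":
        # Element-wise addition of the embedding, broadcast over time.
        return features + noise_emb.unsqueeze(1)
    elif inject_type == "concat":
        # Concatenate along the channel axis; a projection maps back to FEAT_DIM.
        t = features.shape[1]
        cond = noise_emb.unsqueeze(1).expand(-1, t, -1)
        proj = nn.Linear(FEAT_DIM + COND_DIM, FEAT_DIM)
        return proj(torch.cat([features, cond], dim=-1))
    elif inject_type == "cross-attention":
        # Features attend to the embedding as a single key/value token.
        attn = nn.MultiheadAttention(FEAT_DIM, num_heads=4, batch_first=True)
        kv = noise_emb.unsqueeze(1)
        out, _ = attn(features, kv, kv)
        return out
    raise ValueError(f"unknown inject_type: {inject_type}")
```

All three variants keep the feature shape unchanged, so they are drop-in interchangeable via the `--inject_type` flag.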
The full command is also included in `train.sh`.
To see all available training options, run `python train.py --help`.
To evaluate on a test set, run

```
python enhancement.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir> --ckpt <path_to_model_checkpoint> --pretrain_class_model <pretrained_beats>
```

to generate the enhanced .wav files, and subsequently run

```
python calc_metrics.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir>
```

to calculate and output the instrumental metrics.
Both scripts should receive the same `--test_dir` and `--enhanced_dir` parameters.
The `--ckpt` parameter of `enhancement.py` should be the path to a trained model checkpoint, as stored by the logger in `logs/`.
The `--pretrain_class_model` parameter should be the path to the pre-trained BEATs checkpoint.
You may refer to our full commands included in `enhancement.sh` and `calc_metrics.sh`.
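As an illustration of what an instrumental metric computes (the actual metrics produced by `calc_metrics.py` may differ), scale-invariant SDR on a clean/enhanced waveform pair can be written as:

```python
import numpy as np

def si_sdr(reference, estimate):
    """Scale-invariant signal-to-distortion ratio in dB (illustrative only;
    not necessarily one of the metrics reported by calc_metrics.py)."""
    reference = reference - reference.mean()
    estimate = estimate - estimate.mean()
    # Project the estimate onto the reference to isolate the target component.
    alpha = np.dot(estimate, reference) / np.dot(reference, reference)
    target = alpha * reference
    noise = estimate - target
    return 10 * np.log10(np.sum(target ** 2) / np.sum(noise ** 2))
```

Because of the projection step, rescaling the estimate leaves the score unchanged, which is what makes the metric scale-invariant.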
Please cite our paper in your publication when using our research or code:
@inproceedings{hu2024noise,
title={Noise-aware Speech Enhancement using Diffusion Probabilistic Model},
author={Hu, Yuchen and Chen, Chen and Li, Ruizhe and Zhu, Qiushi and Chng, Eng Siong},
booktitle={INTERSPEECH},
year={2024}
}