Speech Enhancement and Dereverberation with Diffusion-based Generative Models

Abstract: Score-based generative models (SGMs) have recently shown impressive results for difficult generative tasks such as the unconditional and conditional generation of natural images and audio signals. In this work, we extend these models to the complex short-time Fourier transform (STFT) domain, proposing a novel training task for speech enhancement using a complex-valued deep neural network. We derive this training task within the formalism of stochastic differential equations (SDEs), thereby enabling the use of predictor-corrector samplers. We provide alternative formulations inspired by previous publications on using generative diffusion models for speech enhancement, avoiding the need for any prior assumptions on the noise distribution and making the training task purely generative, which, as we show, results in improved enhancement performance.
This repository contains the official PyTorch implementations for the 2022 papers:
- Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain, 2022 [1]
- Speech Enhancement and Dereverberation with Diffusion-Based Generative Models, 2022 [2]
Audio examples and further supplementary materials are available on our project page.
Interspeech 2022
Come talk to us at Interspeech 2022 in Incheon, Korea! We'll be at the poster session Wed-P-OS-6-1, Wednesday 21.09.2022, 10:00-12:00.
Installation
- Create a new virtual environment with Python 3.8 (we have not tested other Python versions, but they may work).
- Install the package dependencies via `pip install -r requirements.txt`.
- If using W&B logging (default):
  - Set up a wandb.ai account.
  - Log in via `wandb login` before running our code.
- If not using W&B logging:
  - Pass the option `--no_wandb` to `train.py`.
  - Your logs will be stored as local TensorBoard logs. Run `tensorboard --logdir logs/` to see them.
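Putting these steps together, a typical first-time setup with W&B logging might look like the following sketch (the virtual environment name `sgmse` is just an example):

```bash
# Create and activate a fresh Python 3.8 virtual environment (name is an example)
python3.8 -m venv sgmse
source sgmse/bin/activate

# Install the package dependencies
pip install -r requirements.txt

# Log in to Weights & Biases (only needed for the default W&B logging)
wandb login
```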
Pretrained checkpoints
- For the Speech Enhancement task, we provide pretrained checkpoints for the models trained on VoiceBank-DEMAND and WSJ0-CHiME3, as in the paper. They can be downloaded here.
- For the Dereverberation task, we provide a checkpoint trained on our WSJ0-REVERB dataset. It can be downloaded here.
  - Note that this checkpoint works better with the sampler settings `--N 50 --snr 0.33`.
Usage:
- For resuming training, you can use the `--resume_from_checkpoint` option of `train.py` (see the example below).
- For evaluating these checkpoints, use the `--ckpt` option of `enhancement.py` (see section Evaluation below).
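For example (checkpoint and directory paths are placeholders; the remaining options are described in the Training and Evaluation sections):

```bash
# Resume training from a stored checkpoint
python train.py --base_dir <your_base_dir> --resume_from_checkpoint <path_to_model_checkpoint>

# Enhance a test set with a downloaded checkpoint; for the dereverberation
# checkpoint, --N 50 --snr 0.33 are the recommended sampler settings
python enhancement.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir> --ckpt <path_to_model_checkpoint> --N 50 --snr 0.33
```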
Training
Training is done by executing `train.py`. A minimal running example with default settings (as in our paper [2]) can be run with

```bash
python train.py --base_dir <your_base_dir>
```

where `your_base_dir` should be a path to a folder containing subdirectories `train/` and `valid/` (optionally `test/` as well). Each subdirectory must itself have two subdirectories, `clean/` and `noisy/`, with the same filenames present in both. We currently only support training with `.wav` files.
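That is, the expected folder layout looks like this:

```
<your_base_dir>/
├── train/
│   ├── clean/   # clean speech .wav files
│   └── noisy/   # corresponding noisy .wav files (same filenames as in clean/)
├── valid/
│   ├── clean/
│   └── noisy/
└── test/        # optional
    ├── clean/
    └── noisy/
```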
To see all available training options, run `python train.py --help`. Note that the available options for the SDE and the backbone network change depending on which SDE and backbone you use. These can be set through the `--sde` and `--backbone` options.
Note:
- Our journal preprint [2] uses `--backbone ncsnpp`.
- Our Interspeech paper [1] uses `--backbone dcunet`. You need to pass `--n_fft 512` to make it work (see the example commands below).
- Also note that the default parameters for the spectrogram transformation in this repository are slightly different from the ones listed in the first (Interspeech) paper (`--spec_factor 0.15` rather than `--spec_factor 0.333`), but we've found the value in this repository to generally perform better for both models [1] and [2].
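For example, the two configurations can be selected like this (all other options left at their defaults, which correspond to [2]):

```bash
# Configuration as in the journal preprint [2]
python train.py --base_dir <your_base_dir> --backbone ncsnpp

# Configuration as in the Interspeech paper [1]
python train.py --base_dir <your_base_dir> --backbone dcunet --n_fft 512
```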
Evaluation
To evaluate on a test set, run

```bash
python enhancement.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir> --ckpt <path_to_model_checkpoint>
```

to generate the enhanced `.wav` files, and subsequently run

```bash
python calc_metrics.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir>
```

to calculate and output the instrumental metrics.
Both scripts should receive the same `--test_dir` and `--enhanced_dir` parameters. The `--ckpt` parameter of `enhancement.py` should be the path to a trained model checkpoint, as stored by the logger in `logs/`.
Citations / References
We kindly ask you to cite our papers in your publication when using any of our research or code:
[1] Simon Welker, Julius Richter and Timo Gerkmann. "Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain", ISCA Interspeech, 2022.
[2] Julius Richter, Simon Welker, Jean-Marie Lemercier, Bunlong Lay and Timo Gerkmann. "Speech Enhancement and Dereverberation with Diffusion-Based Generative Models", arXiv preprint arXiv:2208.05830, 2022.
The paper [2] has been submitted to a journal and is currently under review. The appropriate citation for it may therefore change in the future.