Skip to main content

HiFi-GAN

HiFi-GAN implementation
Created on December 21|Last edited on December 22


Abstract

Нужно было реализовать и обучить HiFi-GAN.

Подробности реализации

Структура репозитория:
bin:
download.py # Скрипт для загрузки данных
notebooks:
checks.ipynb # Проверка функционала
datasphere.ipynb # Запуск скриптов в ДатаСфере
resources:
chkpoints # Полные чекпоинты для перезапуска экспериментов
ids # ID запусков W&B
models # Чекпоинты генератора, отобранные по метрике на валидации
predicted # Результаты инференса
test_audio # Фудиофайлы для инференса
src:
config # Файлы конфигурации, содержащие все параметры экспериментов
dataset # Загрузчики данных
loops # Циклы обучения, валидации и т.д.
loss # Функции потерь
models # Архитектуры моделей
transforms # Преобразования аудио (мел спектрограммы и т.д.)
utils # Утилиты (сохранение и загрузка чекпоинтов, визуализация и т.д.)
train.py # Скрипт для обучения
inference.py # Скрипт для инференса

Датасет

LJSpeech-1.1
def __getitem__(self, i: int) -> Tensor:
audio, sample_rate = torchaudio.load(self.ids[i])
assert sample_rate == self.sample_rate
audio = audio / audio.max(dim=1, keepdim=True)[0] * self.wav_scale

if self.segment_size is not None and not self.validation:
if audio.shape[1] >= self.segment_size:
audio_start_bound = audio.shape[1] - self.segment_size
audio_start = np.random.randint(audio_start_bound)
audio = audio[:, audio_start:audio_start + self.segment_size]
else:
# I'm almost 100% sure this is never used
audio = F.pad(audio, (0, self.segment_size - audio.shape[1]), "constant")

if self.augmentations is not None:
audio = self.augmentations(audio)
Все просто, загружаем вавку и нормализуем ее, затем вырезаем небольшой кусочек из случайного места.
Естественно, грузить весь файл ради маленького кусочка кажется не очень эффективно, поэтому в checks.ipynb мы проверили, сколько времени занимает итерирование по всей обучающей выборке. Оказалось, что ~30 секунд при времени обучения по ~25 минут на эпоху. Так что все ок.

Пример данных из валидации. Сверху входная спектрограмма, посередине выходная.
Затем аудио файлы уже после формирования батча преобразуются во входную и выходную спектрограммы. Отличие в том, что на выходную не применяется клиппинг сверху, чтобы не терять градиенты. Как видно из рисунка, разница не очень заметна.

Модели

Генератор: стакаем ConvTranspose и MRF-блоки.
upsample_rates = [8, 8, 2, 2]
upsample_kernel_sizes = [16, 16, 4, 4]
upsample_initial_channel = 512
resblock_kernel_sizes = [3, 7, 11]
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
Дискриминатор: решили взять параметры не из статьи, а из кода.
MPD
# Parameters from the official implementation
mpd_periods = [2, 3, 5, 7, 11]
mpd_channels = [1, 32, 128, 512, 1024, 1024]
mpd_kernel_size = 5
mpd_stride = 3
mpd_use_spectral_norm = False
Примерная структура дискриминатора
class MultiPeriodDiscriminator(nn.Module):
def __init__(self, config):
super().__init__()
discriminators = [PeriodDiscriminator(config, period)
for period in config.mpd_periods]
self.discriminators = nn.ModuleList(discriminators)

def forward(self, real: Tensor, fake: Tensor) -> Dict[str, List]:
# ... do something ...
return {"mpd_real_outs": real_outs,
"mpd_fake_outs": fake_outs,
"mpd_real_features": real_features,
"mpd_fake_features": fake_features}
То есть дискриминатор возвращает словарь со всеми нужными выходами.
MSD
# Parameters from the official implementation
msd_channels = [1, 128, 128, 256, 512, 1024, 1024, 1024]
msd_kernel_sizes = [15, 41, 41, 41, 41, 41, 5]
msd_strides = [1, 2, 2, 4, 4, 1, 1]
msd_groups = [1, 4, 16, 16, 16, 16, 1]
Структура кода и входа-выхода примерно такая же.
Обе модели объединяются в одну обертку:
class Discriminator(nn.Module):
def __init__(self, config):
super().__init__()
self.mpd = MultiPeriodDiscriminator(config)
self.msd = MultiScaleDiscriminator(config)

def forward(self, real: Tensor, fake: Tensor) -> Dict[str, List]:
mpd_out = self.mpd(real, fake)
msd_out = self.msd(real, fake)
return {**mpd_out, **msd_out}
Это позволяет сделать код более запутанным структурированным.
Были так же реализованы хуки для подключения нормализаций и инициализации весов
def _apply_conv_hook(function, module: nn.Module):
name = module.__class__.__name__
if "Conv" in name:
function(module)

def conv_hook(function):
return partial(_apply_conv_hook, function)
Теперь не нужно вручную перебирать все модули.

Функции потерь

Тоже все распихано по отдельным классам.
class DiscriminatorLoss(nn.Module):
def forward(self, real_outs: List[Tensor], fake_outs: List[Tensor]) -> Tensor:
loss = 0
for real, fake in zip(real_outs, fake_outs):
# Classify real as 1 (real), fake as 0 (fake)
loss = loss + ((1 - real) ** 2).mean() + (fake ** 2).mean()
return loss

class FeatureLoss(nn.Module):
def forward(self, real_features: List[Tensor], fake_features: List[Tensor]) -> Tensor:
loss = 0
for real_feature_map, fake_feature_map in zip(real_features, fake_features):
for real_feature, fake_feature in zip(real_feature_map, fake_feature_map):
loss = loss + self.l1_loss(fake_feature, real_feature)
return loss

class GeneratorLoss(nn.Module):
def forward(self, fake_outs: List[Tensor]) -> Tensor:
loss = 0
for fake in fake_outs:
loss = loss + ((fake - 1) ** 2).mean() # Classify fake as 1 (real)
return loss
Тут вроде все ясно, обычный LSGAN + предложенную в статье функцию потерь на сдвижение признаков real и fake друг к другу, чтобы обманывать дискриминатор.

Метрики

Почему бы не посчитать какие-то еще метрики, кроме L1 потерь на спектрограммах? Добавили PSNR и SSIM на них. SSIM используем как валидационную метрику, то есть по ней сохраняем лучшие чекпоинты.

Обучение

# Get generator predictions
pred_audio = generator(input_mel)
pred_mel = mel_spectrogram_loss(pred_audio.squeeze(1))

# Pad targets to match predictions from generator
target_audio = pad_to_length(target_audio, pred_audio.shape[-1])
target_mel = pad_to_length(target_mel, pred_mel.shape[-1])

# Discriminator step
dis_optimizer.zero_grad()

# Discriminator update
discriminator.requires_grad(True)

dis_out = discriminator(target_audio, pred_audio.detach())
dis_losses = compute_losses(dis_criterion, dis_out)
dis_loss = compute_total_loss(dis_losses)
dis_loss.backward()
dis_optimizer.step()

# Generator step
gen_optimizer.zero_grad()

# Turn off discriminator gradients
discriminator.requires_grad(False)

# Get discriminator predictions
dis_out = discriminator(target_audio, pred_audio)

# Generator update
super_losses = compute_losses(super_criterion, target_mel, pred_mel)
gen_losses = compute_losses(gen_criterion, dis_out)
gen_loss = compute_total_loss({**super_losses, **gen_losses})
gen_loss.backward()
gen_optimizer.step()
Запихнули всю муторную работу с подсчетом нужных функций от нужных аргументов в конфигурацию с помощью KeySelector, который выбирает из возвращаемого дискриминатором словаря (помним его?) нужные ключи и посылает их в подсчет функций потерь.
Также реализованы возможности:
  • Клиппинг градиентов
  • Адаптивное отключение шага дискриминатора при его переобучении
  • Очень подробное логирование всех функций и метрик в W&B
Минусы: в старой версии pytorch, которая установлена в датасфере, при обучении течет память, очень медленно, но течет. В новой версии, которая указана в requirements.txt, подобная проблема отсутствует. Подозреваю, что дело в вычислении метрик torchmetrics.
UPD: Поставил новую версию torchmetrics и память течь перестала (см. график), так что подозрения оправдались.



Эксперименты

Вначале запускали на локалке, чтобы исправить все несметное количество багов, образовавшихся в процессе реализации. Памяти катастрофически не хватало, поэтому вот запуск только валидации без обучения.


В итоге все стало работать, и перешли в датасферу обучать финальную модель (до поседения).


Когда логов много, а не мало, легче себя успокоить, что ничего не сломалось. В графе mel_sample как обычно сверху предсказанная спектрограмма, а снизу целевая. Видно, как по ходу обучения она становится все более детальной.
Сгенерированный звук на валидации с тем же спикером довольно быстро становится очень неплохим, робовойса вообще нет, артефакты еле слышны.
UPD: На 10k шаге сменили препроцессинг мелспеки на исправленную версию. Что пронаблюдали:
  • Скорость падения потерь генератора (supervised и adversarial) увеличилась
  • Потери дискриминаторов скакнули вверх
  • Скорость роста валидационных метрик не увеличилась
Какие из этого можно сделать выводы
  • Генератору стало легче обманывать дискриминаторы. Скорее всего до этого они отличали сгенерированное аудио от настоящего по сайд-эффектам паддинга, которые сильно уменьшились с новым препроцессингом
  • Дискриминаторам теперь пришлось выучивать более сложные признаки
  • Возможно, это повлияло на итоговое качество, но без запуска эксперимента со старым препроцессингом на столько же эпох узнать не получится

Результаты

Но нас, конечно, интересует инференс на незнакомом спикере, и тут все несколько похуже. Послушать примеры можно тут.
Слышен робовойс и артефакты. С другой стороны, впечатляет, что модель, которой нужно обучаться сотни часов вычислительного времени, может так быстро давать что-то вполне приемлимое. Все слова вполне можно разобрать. Заметны концевые артефакты, связанные с паддингом: первое слово тише последующих. Рано или поздно моделька должна обучиться это исправлять, потому что на обычной валидации такой проблемы нет.
UPD: После дообучения с новым препроцессингом результаты стали заметно лучше. Робовойс почти пропал, остались только артефакты в виде хрипоты. Концевые артефакты заметно уменьшились. По ссылке выше теперь расположена новая версия, в скрипте для загрузки данных тоже загружается она.

Дополнительно

А что будет, если попытаться снизить доменный сдвиг тестового аудио? Для этого в скрипт инференса добавили возможность менять высоту звука. При увеличении высоты, по идее, голос станет больше похож на женский. Результаты при повышении высоты звука на +7 steps доступны в той же папке гугл диск. В целом улучшение получилось, артефакты хрипоты пропали, но, конечно, не идеально.
Еще вот по приколу вокодер версия funkytown, полученная с помощью модели с 30 эпохи


Воспроизводимость

Все необходимые действия описаны в README.