HiFi-GAN
HiFi-GAN implementation
Created on December 21|Last edited on December 22
Comment
AbstractПодробности реализацииДатасетМоделиФункции потерьМетрикиОбучениеЭкспериментыРезультатыДополнительноВоспроизводимость
Abstract
Подробности реализации
Структура репозитория:
bin:download.py # Скрипт для загрузки данныхnotebooks:checks.ipynb # Проверка функционалаdatasphere.ipynb # Запуск скриптов в ДатаСфереresources:chkpoints # Полные чекпоинты для перезапуска экспериментовids # ID запусков W&Bmodels # Чекпоинты генератора, отобранные по метрике на валидацииpredicted # Результаты инференсаtest_audio # Фудиофайлы для инференсаsrc:config # Файлы конфигурации, содержащие все параметры экспериментовdataset # Загрузчики данныхloops # Циклы обучения, валидации и т.д.loss # Функции потерьmodels # Архитектуры моделейtransforms # Преобразования аудио (мел спектрограммы и т.д.)utils # Утилиты (сохранение и загрузка чекпоинтов, визуализация и т.д.)train.py # Скрипт для обученияinference.py # Скрипт для инференса
Датасет
LJSpeech-1.1
def __getitem__(self, i: int) -> Tensor:audio, sample_rate = torchaudio.load(self.ids[i])assert sample_rate == self.sample_rateaudio = audio / audio.max(dim=1, keepdim=True)[0] * self.wav_scaleif self.segment_size is not None and not self.validation:if audio.shape[1] >= self.segment_size:audio_start_bound = audio.shape[1] - self.segment_sizeaudio_start = np.random.randint(audio_start_bound)audio = audio[:, audio_start:audio_start + self.segment_size]else:# I'm almost 100% sure this is never usedaudio = F.pad(audio, (0, self.segment_size - audio.shape[1]), "constant")if self.augmentations is not None:audio = self.augmentations(audio)
Все просто, загружаем вавку и нормализуем ее, затем вырезаем небольшой кусочек из случайного места.
Естественно, грузить весь файл ради маленького кусочка кажется не очень эффективно, поэтому в checks.ipynb мы проверили, сколько времени занимает итерирование по всей обучающей выборке. Оказалось, что ~30 секунд при времени обучения по ~25 минут на эпоху. Так что все ок.


Пример данных из валидации. Сверху входная спектрограмма, посередине выходная.
Затем аудио файлы уже после формирования батча преобразуются во входную и выходную спектрограммы. Отличие в том, что на выходную не применяется клиппинг сверху, чтобы не терять градиенты. Как видно из рисунка, разница не очень заметна.
Модели
Генератор: стакаем ConvTranspose и MRF-блоки.
upsample_rates = [8, 8, 2, 2]upsample_kernel_sizes = [16, 16, 4, 4]upsample_initial_channel = 512resblock_kernel_sizes = [3, 7, 11]resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
Дискриминатор: решили взять параметры не из статьи, а из кода.
MPD
# Parameters from the official implementationmpd_periods = [2, 3, 5, 7, 11]mpd_channels = [1, 32, 128, 512, 1024, 1024]mpd_kernel_size = 5mpd_stride = 3mpd_use_spectral_norm = False
Примерная структура дискриминатора
class MultiPeriodDiscriminator(nn.Module):def __init__(self, config):super().__init__()discriminators = [PeriodDiscriminator(config, period)for period in config.mpd_periods]self.discriminators = nn.ModuleList(discriminators)def forward(self, real: Tensor, fake: Tensor) -> Dict[str, List]:# ... do something ...return {"mpd_real_outs": real_outs,"mpd_fake_outs": fake_outs,"mpd_real_features": real_features,"mpd_fake_features": fake_features}
То есть дискриминатор возвращает словарь со всеми нужными выходами.
MSD
# Parameters from the official implementationmsd_channels = [1, 128, 128, 256, 512, 1024, 1024, 1024]msd_kernel_sizes = [15, 41, 41, 41, 41, 41, 5]msd_strides = [1, 2, 2, 4, 4, 1, 1]msd_groups = [1, 4, 16, 16, 16, 16, 1]
Структура кода и входа-выхода примерно такая же.
Обе модели объединяются в одну обертку:
class Discriminator(nn.Module):def __init__(self, config):super().__init__()self.mpd = MultiPeriodDiscriminator(config)self.msd = MultiScaleDiscriminator(config)def forward(self, real: Tensor, fake: Tensor) -> Dict[str, List]:mpd_out = self.mpd(real, fake)msd_out = self.msd(real, fake)return {**mpd_out, **msd_out}
Это позволяет сделать код более запутанным структурированным.
Были так же реализованы хуки для подключения нормализаций и инициализации весов
def _apply_conv_hook(function, module: nn.Module):name = module.__class__.__name__if "Conv" in name:function(module)def conv_hook(function):return partial(_apply_conv_hook, function)
Теперь не нужно вручную перебирать все модули.
Функции потерь
Тоже все распихано по отдельным классам.
class DiscriminatorLoss(nn.Module):def forward(self, real_outs: List[Tensor], fake_outs: List[Tensor]) -> Tensor:loss = 0for real, fake in zip(real_outs, fake_outs):# Classify real as 1 (real), fake as 0 (fake)loss = loss + ((1 - real) ** 2).mean() + (fake ** 2).mean()return lossclass FeatureLoss(nn.Module):def forward(self, real_features: List[Tensor], fake_features: List[Tensor]) -> Tensor:loss = 0for real_feature_map, fake_feature_map in zip(real_features, fake_features):for real_feature, fake_feature in zip(real_feature_map, fake_feature_map):loss = loss + self.l1_loss(fake_feature, real_feature)return lossclass GeneratorLoss(nn.Module):def forward(self, fake_outs: List[Tensor]) -> Tensor:loss = 0for fake in fake_outs:loss = loss + ((fake - 1) ** 2).mean() # Classify fake as 1 (real)return loss
Тут вроде все ясно, обычный LSGAN + предложенную в статье функцию потерь на сдвижение признаков real и fake друг к другу, чтобы обманывать дискриминатор.
Метрики
Почему бы не посчитать какие-то еще метрики, кроме L1 потерь на спектрограммах? Добавили PSNR и SSIM на них. SSIM используем как валидационную метрику, то есть по ней сохраняем лучшие чекпоинты.
Обучение
# Get generator predictionspred_audio = generator(input_mel)pred_mel = mel_spectrogram_loss(pred_audio.squeeze(1))# Pad targets to match predictions from generatortarget_audio = pad_to_length(target_audio, pred_audio.shape[-1])target_mel = pad_to_length(target_mel, pred_mel.shape[-1])# Discriminator stepdis_optimizer.zero_grad()# Discriminator updatediscriminator.requires_grad(True)dis_out = discriminator(target_audio, pred_audio.detach())dis_losses = compute_losses(dis_criterion, dis_out)dis_loss = compute_total_loss(dis_losses)dis_loss.backward()dis_optimizer.step()# Generator stepgen_optimizer.zero_grad()# Turn off discriminator gradientsdiscriminator.requires_grad(False)# Get discriminator predictionsdis_out = discriminator(target_audio, pred_audio)# Generator updatesuper_losses = compute_losses(super_criterion, target_mel, pred_mel)gen_losses = compute_losses(gen_criterion, dis_out)gen_loss = compute_total_loss({**super_losses, **gen_losses})gen_loss.backward()gen_optimizer.step()
Запихнули всю муторную работу с подсчетом нужных функций от нужных аргументов в конфигурацию с помощью KeySelector, который выбирает из возвращаемого дискриминатором словаря (помним его?) нужные ключи и посылает их в подсчет функций потерь.
Также реализованы возможности:
- Клиппинг градиентов
- Адаптивное отключение шага дискриминатора при его переобучении
- Очень подробное логирование всех функций и метрик в W&B
Минусы: в старой версии pytorch, которая установлена в датасфере, при обучении течет память, очень медленно, но течет. В новой версии, которая указана в requirements.txt, подобная проблема отсутствует. Подозреваю, что дело в вычислении метрик torchmetrics.
UPD: Поставил новую версию torchmetrics и память течь перестала (см. график), так что подозрения оправдались.
Эксперименты
Вначале запускали на локалке, чтобы исправить все несметное количество багов, образовавшихся в процессе реализации. Памяти катастрофически не хватало, поэтому вот запуск только валидации без обучения.
В итоге все стало работать, и перешли в датасферу обучать финальную модель (до поседения).
Когда логов много, а не мало, легче себя успокоить, что ничего не сломалось. В графе mel_sample как обычно сверху предсказанная спектрограмма, а снизу целевая. Видно, как по ходу обучения она становится все более детальной.
Сгенерированный звук на валидации с тем же спикером довольно быстро становится очень неплохим, робовойса вообще нет, артефакты еле слышны.
UPD: На 10k шаге сменили препроцессинг мелспеки на исправленную версию. Что пронаблюдали:
- Скорость падения потерь генератора (supervised и adversarial) увеличилась
- Потери дискриминаторов скакнули вверх
- Скорость роста валидационных метрик не увеличилась
Какие из этого можно сделать выводы
- Генератору стало легче обманывать дискриминаторы. Скорее всего до этого они отличали сгенерированное аудио от настоящего по сайд-эффектам паддинга, которые сильно уменьшились с новым препроцессингом
- Дискриминаторам теперь пришлось выучивать более сложные признаки
- Возможно, это повлияло на итоговое качество, но без запуска эксперимента со старым препроцессингом на столько же эпох узнать не получится
Результаты
Но нас, конечно, интересует инференс на незнакомом спикере, и тут все несколько похуже. Послушать примеры можно тут.
Слышен робовойс и артефакты. С другой стороны, впечатляет, что модель, которой нужно обучаться сотни часов вычислительного времени, может так быстро давать что-то вполне приемлимое. Все слова вполне можно разобрать. Заметны концевые артефакты, связанные с паддингом: первое слово тише последующих. Рано или поздно моделька должна обучиться это исправлять, потому что на обычной валидации такой проблемы нет.
UPD: После дообучения с новым препроцессингом результаты стали заметно лучше. Робовойс почти пропал, остались только артефакты в виде хрипоты. Концевые артефакты заметно уменьшились. По ссылке выше теперь расположена новая версия, в скрипте для загрузки данных тоже загружается она.
Дополнительно
А что будет, если попытаться снизить доменный сдвиг тестового аудио? Для этого в скрипт инференса добавили возможность менять высоту звука. При увеличении высоты, по идее, голос станет больше похож на женский. Результаты при повышении высоты звука на +7 steps доступны в той же папке гугл диск. В целом улучшение получилось, артефакты хрипоты пропали, но, конечно, не идеально.
Еще вот по приколу вокодер версия funkytown, полученная с помощью модели с 30 эпохи
Воспроизводимость
Все необходимые действия описаны в README.
Add a comment