Skip to main content

설명: 정규화 없이 학습한 ResNet에서 성능 격차를 줄이기 위한 신호 전파 특성화

BatchNorm은 거의 모든 최신 이미지 분류기에서 핵심 구성 요소이지만, 실무적으로는 여러 가지 과제를 야기합니다. 그렇다면 BatchNorm 없이도 유사한 성능을 낼 방법이 있을까요? 답은 “그렇다”로 보입니다! 이 글은 AI로 번역된 기사입니다. 오역이 의심되는 부분이 있다면 댓글로 알려 주세요.
Created on September 15|Last edited on September 15
이 보고서의 일부로, 우리는 “정규화 없는 ResNet에서 성능 격차를 줄이기 위한 신호 전파 특성화논문을 자세히 살펴봅니다.
이 보고서는 두 개의 주요 섹션으로 나뉩니다:
  1. 소개 - 논문의 핵심 아이디어와 기여에 대한 고수준 소개
  2. 코드 구현 - PyTorch로 구현한 코드와 함께 논문을 깊이 있게 이해하기 (코드 출처: timm)
이 보고서는 연구 논문을 쉽고 소화 가능한 단위로 나누어 설명하는 것을 목표로 하며, 상향식이 아닌 상위에서 하위로 내려오는 방식으로 전개합니다. 먼저 고수준의 개념을 파악한 뒤, 논문의 세부 내용으로 들어갑니다. 또한 초보자의 관점에서 논문을 설명하는 것을 목표로 하므로, 일부 섹션에서는 원 논문보다 다소 길어질 수 있습니다.
이 보고서의 대부분 섹션에서는 연구 논문의 원문을 직접 인용한 뒤, 제가 이해한 바를 바탕으로 핵심 아이디어를 더 쉽게 풀어 설명하기 위해 이를 패러프레이즈합니다.
💭: 이 보고서 전반에는 이런 사이드바가 곳곳에 등장합니다. 이는 여러분께 도움이 될 만한 저의 개인적인 코멘트입니다. 때로는 재미있는 에피소드나 실패한 실험담도 함께 담겨 있을 수 있습니다.
💭: 이 보고서는 길고, 일반적인 블로그 글보다 더 깁니다. 하지만 완결성을 갖추고 있습니다. 저는 이 보고서에서 어떤 개념도 빠뜨리지 않으려고 의식적으로 노력했습니다. 제 생각에는, 이 보고서는 여러 부분으로 나누어 읽는 것이 가장 좋으며, 처음 접하는 독자에게는 반복해 읽는 과정이 필요할 수 있습니다. NF-ResNets 처음으로
💭: 이 논문은 기존 연구를 바탕으로 하고 있으므로, 여기에 언급된 여러 개념에 대한 사전 지식이 어느 정도 필요합니다. 사전 지식 아래 섹션이 큰 도움이 될 것입니다. 이 보고서만으로는 이 논문을 충분히 설명하지 못했다고 느끼신다면, 과거 연구를 함께 살펴보며 빈틈을 메우는 논문 읽기 세션을 열어도 좋습니다. 원하시면 알려 주세요. 제가 준비하겠습니다 :)
💭: 제 개인적인 목표는 이 연구 논문을 간결하고 쉽게 읽을 수 있는 형식으로 정리한 보고서를 만드는 것입니다. 보고서의 일부가 이해되지 않거나 혼란스럽다면, 보고서 말미에 건설적인 피드백을 자유롭게 남겨 주세요.

사전 지식

이 보고서를 가장 효과적으로 활용하시려면, 독자께서 다음 내용에 대해 전반적인 이해를 갖추고 계시기를 권합니다. ResNet, 배치 정규화, ReLU 활성화 및 가중치 표준화다음 주제들에 대한 전반적인 소개를 빠르게 익히는 데 도움이 될 만한 자료를 아래에 소개합니다.
생각: 위에서 언급한 두 번째 자료에는 제가 조금 편파적입니다. 😉

소개

논문 서론에서
BatchNorm은 딥러닝�� 핵심 연산 프리미티브가 되었으며, 최첨단 이미지 분류기 대부분에서 사용된다. BatchNorm의 다양한 이점이 확인되어 왔다. 손실 지형을 매끈하게 만들어 더 큰 학습률로 학습할 수 있게 하고, 미니배치로 추정한 배치 통계에서 발생하는 노이즈가 암묵적 정규화를 제공한다. 그러나 BatchNorm에는 단점도 많다. 동작이 배치 크기에 크게 의존하여, 디바이스별 배치 크기가 너무 작거나 너무 클 때 성능이 저하되고, 학습 중 모델의 동작과 추론 시 동작 사이에 불일치를 야기한다. 여러 대체 정규화 레이어가 제안되었지만, 일반적으로 이러한 대안들은 일반화 성능이 떨어지거나 추론 시 추가 연산 비용 등 고유한 단점을 도입한다. 다른 연구 흐름에서는 은닉 활성화를 정규화하는 레이어를 아예 제거하는 방향을 모색해 왔다.
이 논문에서 저자들은 정규화 레이어 없이도 깊은 ResNet을 안정적으로 학습시키고, 최첨단에 견줄 만한 테스트 정확도를 달성하는 일반적인 학습 레시피를 제시하고자 한다! 배치 정규화(BatchNorm) 컴퓨터 비전에서 딥러닝 연구를 발전시키는 데 핵심적인 역할을 해왔지만, 최근 몇 년 사이에는 활성화를 정규화하는 레이어를 아예 제거하려는 새로운 연구 흐름이 등장했다.
질문: 왜 BatchNorm을 제거하고 싶을까요? 이에 대한 답은 다음에서 다룹니다 왜 Normalizer-Free 네트워크가 필요할까요? BatchNorm에는 무엇이 문제일까요? 이 보고서의 섹션
이 연구 논문은 이러한 연구 흐름을 따르며, 그 흐름의 핵심 기여 다음과 같습니다:
  • 신호 전파 플롯저자들은 실무자가 딥 레지듀얼 네트워크에서 초기화 시점의 순전파 동안 신호 전파를 점검할 수 있도록 돕는 간단한 시각화 도구 세트를 제안한다.
  • 스케일드 가중치 표준화: 저자들은 ReLU 또는 Swish 활성화와 가우시안 가중치를 사용하는 과거 비정규화 ResNet에서의 핵심 실패 모드를 지적한다. 이러한 비선형 함수들의 평균 출력이 양수이기 때문에, 네트워크 깊이가 증가할수록 각 채널의 은닉 활성화 평균의 제곱이 급격히 증가한다. 이를 해결하기 위해 저자들은 다음을 제안한다 스케일드 웨이트 스탠더다이제이션 의 확장판인 가중치 표준화본질적으로, 가중치 표준화는 합성곱 층의 가중치를 정규화하여 가중치의 평균을 0, 분산을 1로 맞추는 것이다.
  • BatchNorm과 견줄 만한 성능 상대 항목저자들은 Scaled Weight Standardization과 결합한 정규화 없는 네트워크 구조를 ImageNet의 ResNet에 적용하여, 최초로 288층에 달하는 매우 깊은 네트워크에서도 BatchNorm을 사용한 ResNet에 견주거나 그보다 나은 성능을 달성했다.
❓: 무엇을 의미하나요?이러한 비선형 함수들의 출력 평균은 양수이다위의 두 번째 항목에서 “…”는 무엇을 의미하나요? 기억하세요, ReLU는 다름 아닌 a max(0,x)max(0, x) 입력이 주어졌을 때 수행되는 연산 xx그래서 ReLU 활성화의 출력은 항상 양수이며, 그 결과 평균이 다음을 향해 이동합니다 mean(x)>0mean(x)>0이를 상쇄하기 위해, 저자들은 뒤에서 자세히 다룰 Scaled Weight Standardization을 도입했습니다 스케일드 가중치 표준화 보고서의 섹션
💭: 지금까지 BatchNorm 없이 구성된 네트워크는 동급의 최신 성능을 내지 못했기 때문에, 왜 이것이 상당히 흥미로운지 짐작하실 겁니다! 실제로, 그들의 후속 논문심지어 저자들은 새로운 최고 성능까지 달성합니다! 하지만, 지금은 너무 앞서 나가지 말죠..

왜 Normalizer-Free 네트워크가 필요할까요? BatchNorm에는 무엇이 문제일까요?

혹시 “BatchNorm이 뭐가 문제죠? 지금까지 거의 모든 네트워크에서 봤는데요…”라고 궁금하셨다면, 이 섹션에서 그 답을 찾아볼 수 있습니다.
기본적으로 BatchNorm은 매우 좋은 특성을 지니고 있지만 단점도 있습니다.
논문에서 언급된 주요 장점들은 다음과 같습니다:
BatchNorm의 이점으로 알려진 것들이 여럿 있습니다. 손실 지형을 더 매끄럽게 만들어 주며 (Santurkar 외, 2018), 더 큰 학습률로 학습할 수 있게 해 주며 (Bjorck 외, 2018) 그리고 미니배치로 추정한 배치 통계에서 발생하는 노이즈가 암묵적 정규화를 도입하며 (Luo 외, 2019). 또한 항등 스킵 연결을 가진 깊은 잔차 네트워크에서 초기화 시 신호 전파도 잘 유지합니다 (De & Smith, 2020).
하지만 BatchNorm에는 단점도 많습니다. 또한, 논문에서는 다음과 같이 말합니다:
동작이 배치 크기에 강하게 의존하며, 디바이스당 배치 크기가 너무 작거나 너무 클 때 성능이 저하됩니다 (Hoffer 외, 2017), 그리고 모델이 학습 중일 때와 추론 시점의 동작 사이에 불일치를 초래합니다. 또한 BatchNorm은 메모리 오버헤드도 증가시킵니다 (Rota Bulo 외, 2018), 그리고 구현 오류의 흔한 원인이기도 합니다 (Pham 외, 2019). 또한 서로 다른 하드웨어에서 학습된 배치 정규화 모델을 재현하기가 종종 어렵습니다.
따라서 이 연구 흐름은 다음과 같은 논리를 따릅니다 - "배치 정규화의 장점은 유지하면서 단점을 없앤 정상화 없는 네트워크를 찾을 수 있다면, 더 작은 배치 크기로도 학습할 수 있고, 학습과 추론 속도를 높이며 메모리 오버헤드까지 줄일 수 있습니다!"""
또한 NF 네트워크 전반에서 신호가 잘 전파되기를 바랍니다. 그렇다면 신호 전파를 측정할 방법이 있을까요? NF 네트워크를 BatchNorm을 사용하는 대응 모델과 어떻게 비교할 수 있을까요? 여기서 신호 전파 플롯이 등장합니다.

신호 전파 플롯

생각: 이 부분은 제가 논문에서 가장 좋아하는 섹션 중 하나입니다. 기본적으로 신호 전파 플롯은 네트워크 내부의 “신호 전파”를 측정하는 데 도움을 주는 플롯입니다. 방법은 간단합니다. 단 한 번의 순전파 동안 네트워크 내부의 여러 지점에서 통계를 계산하고, 그 값을 그래프로 그립니다.
논문에서:
최근 논문들이 ResNet의 신호 전파를 이론적으로 분석하긴 했지만, 실제로는 새로운 모델을 설계하거나 기존 아키텍처를 수정할 때 특정 심층 네트워크 내부의 서로 다른 깊이에서 히든 활성값의 스케일을 경험적으로 평가하는 일이 드뭅니다. 이에 비해, 네트워크 내부의 여러 지점에서 히든 활성값의 통계를 그려 보고, 입력 배치를 무작위 가우시안 분포 샘플이나 실제 학습 예시로 조건화해 비교하는 방식이 매우 유익하다는 것을 우리는 발견했습니다.
저자들은 네트워크 내부의 서로 다른 지점에서 히든 활성값의 통계를 그려보는 것이 유용하다는 사실을 발견했고, 이러한 플롯을 다음과 같이 명명했습니다 신호 전파 플롯.
생각: 이상적으로는 네트워크 전반에서 히든 활성값의 평균이 0이고 분산이 1이 되도록 유지하는 것이 좋습니다. 이는 “양호한” 신호 전파를 판단하는 좋은 기준입니다.
저자들은 차원이 다음과 같이 표기된 4차원 입력 및 출력 텐서를 고려합니다 NHWC, 어디에서 배치 차원을 나타내며, C 채널을 나타내며, H 그리고 더블유 두 개의 공간 차원인 높이와 너비를 나타낸다.
���자들은 다음과 같은 형태의 항등 잔차 블록을 가정한다:
xL+1=fL(xL)+xLx_{L+1} = f_{L}(x_{L}) + x_{L}
여기서, xlx_l 를 나타낸다 lthl_{th} 블록, flf_l 가 계산하는 함수를 나타낸다 lthl_{th} 잔차 분기.
😕: 수식이 헷갈리시나요? 이 수식은 다음에 나오는 사전 활성화 ResNet 블록을 나타냅니다. 연구 논문의 그림 4(a). 여기 f(.)f(.) 다음과 같이 BatchNorm, ReLU, Conv 연산으로 구성된 잔차 분기에서의 함수를 나타낸다 그림 4(a)이는 다음 섹션들에서 잔차 블록이라고도 부른다.
그 다음, 신호 전파 플롯(SPP)을 생성하기 위해 저자들은 초기화 방식에 따라 네트워크를 초기화한다(이는 He 초기화, 또는 Glorot 초기화, 또는 다른 어떤 방식이든), 그리고 단위 가우시안 분포에서 샘플링한 입력 예제들의 배치를 네트워크에 제공한다.
생각: 간단히 말해, 원하는 네트워크를 선택해 적절한 초기화 방식으로 초기화한 다음, 평균이 0이고 분산이 1인 가우시안 입력에 대해 한 번의 포워드 패스를 수행하라.
이 입력이 네트워크를 따라 전파되면서, 저자들은 각 잔차 블록의 끝에서 다음 통계를 플로팅한다:
  • 평균 채널 제곱 평균NHW 축 전반의 평균을 제곱한 값으로 계산하고, 그 후 C 축을 따라 평균낸 값.
  • 평균 채널 분산, NHW 축 전반에서 채널 분산을 계산한 뒤, C 축을 따라 평균낸 값.
  • 잔차 분기 끝에서의 평균 채널 분산, 스킵 경로와 병합하기 전에.
생각: 잠깐 멈추고 한숨 고르며, 눈을 감고 SPP가 어떤 느낌일지 떠올려 보는 건 어떨까요? 어떻게 생겼을 것 같나요? 여러분의 기대는 어떤가요?
아래에서 그림 1, 우리는 …의 예를 볼 수 있습니다 신호 전파 플롯 신호 전파를 내부에서 측정하는 ResNet-V2 600 네트워크를 대상으로 위에서 언급한 세 가지 통계를 두 개의 네트워크에 대해 플로팅합니다. 하나는…를 사용한 경우이고, 다른 하나는…를 사용하지 않은 경우입니다. BN–ReLU–Conv 잔차 블록 내부의 레이어 순서와, 다른 하나는 ReLU–BN–Conv 순서…
❓: BN–ReLU–Conv 레이어 순서와 ReLU–BN–Conv 순서는 무엇이 다른가요? BN과 ReLU 활성화 레이어의 위치를 서로 바꾸면 됩니다. 연구 논문의 그림 4(a).
🤔: 플롯이 기대와 일치하나요? 아니라면, 그 이유는 무엇인가요?
그림 1: BatchNorm, ReLU 활성화, He 초기화를 사용한 ResNetV2-600을 512px 해상도에서 평균 0, 분산 1의 가우시안 입력으로 초기화했을 때의 신호 전파 플롯. 검은 점은 각 스테이지의 끝을 나타냅니다. 파란 플롯은 BN–ReLU–Conv 순서를, 빨간 플롯은 ReLU–BN–Conv 순서를 사용합니다.
기억하신다면 ResNet 논문각 ResNet 모델은 네 개의 스테이지로 나뉘며, 각 스테이지는 서로 다른 개수의 블록을 갖습니다. ResNet-V2의 경우 총 200개의 잔차 블록이 있으며, 각 스테이지에는 50개의 블록이 있습니다.
💭: 이 모델은 저자들이 제공한 코드에서 정의되었습니다 여기.
ResNet-V2 600에는 잔차 블록이 200개 있으므로, 그 결과가 그림 1입니다. 각 잔차 블록마다 이 세 가지 통계값을 계산하여 플로팅합니다.
💭: 그림 1은 한동안 정말 혼란스러웠습니다. 저는 어쩐지 ResNet-V2 600이 잔차 블록을 600개 갖고 있다고 생각했고, 그래서 X축의 최댓값이 200인 것을 보고 더 헷갈렸죠. 그런데 코드를 확인하고 나니 이제 이해가 됩니다. 여러분께는 저만큼 혼란스럽지 않기를 바랍니다.

핵심 관찰 사항

⚠️: 여기에는 특히 주의하세요.
정규화기 없는 신경망을 살펴볼 때 큰 도움이 될, 그림 1에서 드러난 핵심 패턴:
  1. 그림 1(b)에서, 평균 채널 분산은 깊이에 따라 선형적으로 증가한다 해당 스테이지에서, 그리고 각 전이 블록에서 초기화됩니다.
  2. 그림 1(a)에서, BN–ReLU–Conv 순서의 경우, 평균 제곱 채널 평균도 깊이에 따라 선형적으로 증가하는 유사한 거동을 보인다.
💭: 호기심 많은 독자분들께는 …을 참고하시길 권합니다 SkipInit 연구 논문 같은 저자들이 쓴 다른 논문으로, 초기화 시 잔차 분기에서 신호를 감쇠시키는 스칼라로 BatchNorm을 대체한 연구입니다.
💭: 이러한 패턴을 Normalizer-Free ResNet에서도 모방할 수 있다면, 정규화를 쓰지 않는 새로운 네트워크도 잘 학습되고 정규화된 대응 모델들과 견줄 만한 성능을 내도록 보장할 수 있습니다.

PyTorch에서 사용자 정의 모델의 신호 전파 플롯 그리기

좋아요, 여기까지입니다 — 이것이… 신호 전파 플롯 이론적으로는 그렇습니다. 우리가 직접 이런 플롯을 만들어 재현해 볼 수 있다면 재미있지 않을까요? 사실 가능합니다. 제가 가장 좋아하는 라이브러리 중 하나를 사용하면 됩니다 — timm - 바로 이것을 하기 위해!
import torchvision
from timm.utils.model import extract_spp_stats, avg_ch_var, avg_ch_var_residual, avg_sq_ch_mean

model = torchvision.models.resnet50()
spp_stats = extract_spp_stats(m,
hook_fn_locs=['layer?.?', 'layer?.?', 'layer?.?.bn3'],
hook_fns=[avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual])

# plot stats
fig, ax = plt.subplots(1, 3, figsize=(18,3), sharey=True)
ax[0].plot(stats['avg_sq_ch_mean'], label='avg_sq_ch_mean');
ax[0].legend(); ax[0].grid();
ax[1].plot(stats['avg_ch_var'], label='avg_ch_var');
ax[1].legend(); ax[1].grid();
ax[2].plot(stats['avg_ch_var_residual'], label='avg_ch_var_residual');
ax[2].legend(); ax[2].grid();
그림 2: ResNet-50의 BN–ReLU–Conv 순서를 사용한 신호 전파 플롯 timm
생각: 저는 SPP를 기여할 수 있는 행운을 얻었습니다 timm. ResNet V2 600에 대해 그림 1을 재현할 수 있는 전체 노트북은 다음을 참조하세요 여기.

정규화 없는 ResNet

생각: 아래 섹션은 위의 내용보다 조금 더 복��하게 느껴질 수 있지만, 바로 여기에서 Normalizer-Free ResNet이 소개되기 때문에 가장 중요한 부분이기도 합니다. 이해가 어려우면 이 섹션을 다시 읽어 보시거나, 궁금한 점이 있으면 언제든지 제게 문의하세요.
이제 BatchNorm으로 정규화된 네트워크에 대해 충분히 탄탄한 이해를 갖췄고, 분석을 도와줄 SPP도 준비됐으니, 정규화를 쓰지 않으면서도 신호 전파가 좋고 학습이 안정적이며, 배치 정규화된 모델에 필적하는 테스트 정확도에 도달하는 ResNet의 변형들을 살펴볼 준비가 되었습니다.
생각: 사실 우리가 필요한 것은 BatchNorm으로 정규화된 네트워크와 SPP에 대한 충분한 이해뿐이었습니다. 본질적으로 저자들은 정규화된 ResNet의 SPP를 모방하는 SPP를 가진 정규화기 없는 네트워크를 만듭니다. 왜일까요? 이에 대한 답은 이미 다음에서 다루었습니다 질의응답 섹션 아래
두 가지가 있다 핵심 관찰정규화기 없는 네트워크를 설계할 때 반드시 기억하고 동일한 효과를 재현해야 하는 사항들:
  1. BatchNorm은 잔차 블록의 입력을 축소한다 입력 신호의 표준편차에 비례하는 계수만큼.
  2. 각 잔차 블록은 신호의 분산을 증가시킨다 대략 일정한 계수만큼. (평균 채널 분산이 선형적으로 증가한다)
저자들은 다음과 같은 형태의 새로운 네트워크를 설계하여 이러한 효과를 모방할 것을 제안했다:
xl+1=xl+αfl(xlβl)x_{l+1} = x_{l} + α f_{l}(\frac{x_l}{β_l})
여기서, xlx_l 를 나타낸다 lthl_{th} 잔차 분기와 fl(.)f_l(.) 를 나타낸다 lthl_{th} 잔차 분기.
메모: 이는 잔차 블록과는 대조적이라는 점에 유의하라 xL+1=fL(xL)+xLx_{L+1} = f_{L}(x_{L}) + x_{L}
저자들은 다음과 같은 방식으로 새로운 정규화기 없는 네트워크를 설계했다:
주의: 여기는 특히 신경 써서 읽으세요.
  • fl(.)f_l(.), 초기화 시 잔차 분기가 계산하는 함수가 분산 보존이 되도록 매개변수화한다. 즉, Var(f(xl))=Var(xl)Var(f(x_l)) = Var(x_l).
  • βlβ_l 는 다음과 같이 선택된 스칼라 함수이다 var(xl)\sqrt{var(x_l)}. 이는 다음을 보장한다 fl(.)f_l(.) 분산이 1이다.
  • αα 는 블록 사이에서 분산이 증가하는 속도를 제어하는 스칼라 하이퍼파라미터이다.
메모: 네, 정보는 많은데 설명은 훨씬 적다는 걸 알고 있어요. 설명 다음은 문답 형식으로 이어집니다.

문답: 정규화기 없는 ResNet

메모: 이 절에서는 제가 처음 이 논문을 읽을 때 매우 혼란스러웠던 몇 가지 질문에 답해 보겠습니다.
❓ …라고 말할 때 무슨 뜻인가요? fl(.)f_l(.) 분산을 보존한다는 뜻인가요?
기본적으로 이는 다음을 의미합니다 fl(.)f_l(.) 입력의 분산을 바꾸지 않는다는 뜻입니다. 즉,
Var(fl(x))=Var(xl)Var(f_l(x)) = Var(x_l)
❓ 왜 우리가 원하는가요? fl(.)f_l(.) 분산을 보존한다는 것?
논문에서:
이 제약을 통해 네트워크에서 신호가 증가하는 양상을 논리적으로 파악하고, 분산을 해석적으로 추정할 수 있습니다.
그리고 분산을 해석적으로 계산할 수 있으므로, 이는 곧 다음의 값들도 계산할 수 있다는 뜻입니다. βlβ_l 해석적으로
❓ 왜 정규화가 없는 네트워크가 ResNet의 SPP 추세를 모방해야 할까요?
에서 오픈리뷰, 이에 대한 답변은 아래에서 저자들이 제공합니다:
정규화가 없는 네트워크에서 신호가 어떻게 전파되어야 하는지는 본질적으로 설계의 문제입니다. 부록 G.2에서, 우리는 처음에 분산을 일정하게 유지하도록 네트워크를 설계하는 방식을 탐색했다고 언급합니다. 사전 지식이 없다면 이것이 더 우수한 선택이라고 가정할 수도 있습니다. 그러나 우리는 이러한 네트워크의 성능이 기대만큼 좋지 않다는 것을 발견했고, 이미 효과가 검증된 신호 전파 템플릿을 모방하는 것이 더 나은 설계 선택이라고 판단했습니다. 이는 우리의 실험 결과로도 뒷받침됩니다.
❓ 왜인가요? βlβ_l 으로 선택된 var(xl)\sqrt{var(x_l)}?
var(xl)\sqrt{var(x_l)} 입력 신호의 표준편차입니다 xlx_l 에게 lthl_{th} 잔차 블록. 신호를 표준편차로 나누면 다음을 보장할 수 있습니다 xlx_l 단위 분산을 가지므로 안정적인 학습에 바람직합니다!
❓ 무엇인가요? αα?
새로운 정규화기-프리 블록의 설계에서 알다시피,
xl+1=xl+αfl(xlβl)x_{l+1} = x_{l} + α f_{l}(\frac{x_l}{β_l})
그러므로, βlβ_l 입니다 으로 선택된 var(xl)\sqrt{var(x_l)}, 그러므로, xlβl\frac{x_l}{β_l} 단위 분산을 가진다. 그러므로, fl(.)f_l(.) 분산을 보존하므로, fl(xlβl)f_l(\frac{x_l}{β_l}) 또한 단위 분산을 가진다.
이제, 분산을 계산하면, xl+1=xl+αfl(xlβl)x_{l+1} = x_{l} + α f_{l}(\frac{x_l}{β_l}) 다음을 얻는다:
Var(xl+1)=Var(xl)+Var(α.fl(xlβl)Var(x_{l+1}) = Var(x_l) + Var(α.f_l(\frac{x_l}{β_l})
그러므로,
Var(xl+1)=Var(xl)+α2Var(x_{l+1}) = Var(x_l) + α^2
💭: 여기까지 오고 이해하는 데 2주가 걸렸다 Var(xl+1)=Var(xl)+α2Var(x_{l+1}) = Var(x_l) + α^2, 그래서 처음에 이해하지 못해도 괜찮다. 이 부분이 이해되지 않으면 알려줘! NF-ResNets를 완전히 이해하는 데는 중요하지만, 구현 관점에서는 그다지 중요하지는 않다.
❓ 그 둘은 어떻게 되나요? 핵심 관찰s? 정규화 없이 사용하는 ResNet(NF-ResNet)은 이것들을 어떻게 처리하나요?
💭: 이것은 논문에 대한 나의 이해를 바탕으로 한 것이며, 다른 곳에서 명시적으로 설명된 바는 없다.
그럼, 우리가 정규화 없이도 재현하고 싶은 BatchNorm의 역할은 두 가지다:
  1. 잔차 분기의 입력을 다운스케일링하기
  2. 각 잔차 블록마다 신호의 분산을 대략 일정한 배수로 증가시킨다
그러므로, Var(xl+1)=Var(xl)+α2Var(x_{l+1}) = Var(x_l) + α^2그러므로 α는 블록 사이에서 분산이 증가하는 속도를 제어하는 스칼라 하이퍼파라미터다. 따라서 각 잔차 블록이 신호의 분산을 대략 일정한 배수로 증가시키도록 보장한다.
또한 잔차 분기의 입력은 xlβl\frac{x_l}{β_l}, 즉 다운스케일링된 입력이며 배수는 βlβ_l.
💭: 위에서 Normalizer Free 네트워크의 원리와 배경을 최대한 쉽게 풀어 설명해 보았습니다. 이 섹션에서 이해되지 않는 부분이 있었다면, 보고서 끝부분에서 언제든지 알려 주세요. :)

PyTorch/Jax의 NF-ResNet

로스 와이트먼 제가 좋아하는 라이브러리 중 하나에 이미 NF-ResNet을 구현해 두는 훌륭한 일을 해주었습니다 — timm! 이제 NF-ResNet을 만드는 일은 다음과 같이 간단합니다:
import timm
import torch

m = timm.create_model('nf_resnet50')
x = torch.randn(1, 3, 224, 224)
m(x).shape

>> torch.Size([1, 1000])
위에서는 다음을 사용해 NF-ResNet-50 모델을 간단히 생성합니다 timm 임의의 입력을 넣어 분류 출력을 얻을 수 있습니다. 이 네트워크는 사용자 정의 데이터셋에도 활용해 파인튜닝할 수 있습니다. 동일한 학습 스크립트와 파인튜닝 절차를 그대로 따라 하면 됩니다. 여기 timm 문서에서
생각: 저는 정말 좋아합니다 timm이는 가장 빠르게 성장하는 라이브러리 중 하나이며, 다음에 의해 지속적으로 최신 상태로 유지됩니다 로스. 최신 연구 논문은 결국 timm 정말 정말 빠르게! 저도 운 좋게 함께 일할 기회가 있었습니다 timm 문서 에 대한 보다 심층적인 문서를 제공하는 프로젝트 timm.
생각: PyTorch에서 NF-ResNet을 실험해 보고 싶다면, 사용하기 위해 timm 이러한 네트워크를 시작하는 가장 쉬운 방법 중 하나입니다. 또한 훌륭한 자료로는 아유시 타쿠르 를 사용하며 PyTorch Lightning을 기반으로 합니다 timm 배치 크기를 빠르게 실험해 보려는 경우.
코드에서 NF-ResNet을 빠르게 시작하는 간편한 방법을 살펴봤으니, 이제 소스 코드를 살펴보겠습니다. 이러한 네트워크의 timm 처음부터 이러한 네트워크를 어떻게 만들 수 있는지 이해하기 위해서입니다.
💭: 우리는 또한 …을 살펴볼 수도 있었어요 공식 코드 구현 DeepMind의 JAX 구현도 있지만, 단순함을 위해 여기서는 PyTorch로 진행하겠습니다. 우리는 다음의 소스 코드를 사용할 예정입니다 timm.
이제 Normalizer Free 블록을 만들기 위해, 우리는 다음 식을 다시 구성해야 합니다.
xl+1=xl+αfl(xlβl)x_{l+1} = x_{l} + α f_{l}(\frac{x_l}{β_l})
아래에서 PyTorch로 해봅시다:
import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg
from timm.models.layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
get_act_layer, get_act_fn, get_attn, make_divisible

class NormFreeBlock(nn.Module):
"""Normalization-Free pre-activation block."""

def __init__(
self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,
skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):
super().__init__()
out_chs = out_chs or in_chs
mid_chs = int(in_chs * bottle_ratio)
self.alpha = alpha
self.beta = beta

if in_chs != out_chs or stride != 1 or dilation != first_dilation:
self.downsample = DownsampleAvg(
in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer)
else:
self.downsample = None

self.act1 = act_layer()
self.conv1 = conv_layer(in_chs, mid_chs, 1)
self.act2 = act_layer(inplace=True)
self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
self.act3 = act_layer()
self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0.)

def forward(self, x):
out = self.act1(x) * self.beta

# shortcut branch
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(out)

# residual branch
out = self.conv1(out)
out = self.conv2(self.act2(out))
out = self.conv3(self.act3(out))
out = out * self.alpha + shortcut
return out
위의 코드는 아래 그림에 나온 Pre-Activation ResNet 블록과 유사한 Normalizer Free 병목 블록을 재구성합니다. 하지만 BatchNorm 없이:

그림 3: 사전 활성화 ResNet을 위한 잔차 블록 (He 외, 2016a).
���: 여기서는 독자에게 다음 내용을 직접 매핑해 보도록 과제로 남겨 둡니다. 그림 3 - (a)와 (b) 의 코드 구현으로부터 timm 위에서 공유한 Normalizer Free 블록에 해당합니다. 이 코드 스니펫은 Pre-Activation ResNet 블록과 Pre-Activation ResNet 전환 블록을 모두 구현할 수 있습니다.
💭: 또한 Nf-ResNet을 만들려면, 언급된 다양한 ResNet 구성에 따라 블록을 반복하면 되는 간단한 작업입니다. 여기. 이 부분도 과제로 남겨 두지만, 필요하다면 더 깊게 파고들 준비가 되어 있습니다. 이를 과제로 둔 이유는, 제공된 소스 코드를 사용해 네트워크를 직접 재구성해 보면서 NF-ResNet을 이해하는 데 정말 좋은 연습/프로젝트가 되기 때문입니다. timm 안내를 위해.
💭: 전체 구현을 이해하는 데 필요한 모든 요소는 이미 이 보고서에 모두 공유되어 있습니다.

스케일드 가중치 표준화

💭: 아직 풀어야 할 미스터리가 남아 있지만, 이제 마지막 하나만 남았습니다.
그림 4: ReLU 활성화를 사용하는 ResNetV2-600 네트워크의 세 가지 변형에 대한 SPP. 빨간색은 ReLU–BN–Conv 순서의 BatchNorm 적용 네트워크. 초록색은 He 초기화와 α = 1을 사용하는 Normalizer-Free 네트워크. 청록색은 동일한 Normalizer-Free 네트워크에 Scaled Weight Standardization을 적용한 경우.
저자들은 위에서 소개한 NF-ResNet 네트워크를 구현하고 이를 비교했다 신호 전파 플롯 ReLU–BN–Conv 순서의 정규화된 네트워크와 함께(이미 앞에서 본) 그림 1). NF-ResNet은 다음과 같이 초기화했다 He 초기화 그리고 α = 1. 두 SPP는 …에 표시되어 있습니다 그림 4
보시다시피 그림 4, 두 SPP는 동일하지 않습니다. 사실, 다음과 같은 두 가지 예상치 못한 패턴을 관찰할 수 있습니다:
  1. NF-ResNet의 경우, 채널 평균의 제곱값의 평균이 깊이가 깊어질수록 빠르게 증가하여, 평균 채널 분산을 초과할 정도로 큰 값에 도달합니다. (그림 5(a)와 그림 5(b) 비교)
  2. NF-ResNet의 경우, 잔차 분기의 경험적 분산 규모는 일관되게 1보다 작습니다. (그림 5(c) 녹색)
💭: 앞서 말씀드렸듯이, 네트워크 전반에서 활성값의 평균이 0이고 분산이 1인 상태가 바람직한 신호 전파로 간주됩니다. 따라서 깊이가 증가함에 따라 채널 평균의 제곱값 평균이 급격히 커지고, 분산 값이 일관되게 1보다 작은 경우는 네트워크의 불안정을 의미합니다.
평균 이동의 발생을 막고 잔차 분기가 fl(.)f_l(.) 분산을 보존하도록, 저자들은 다음을 제안했습니다 스케일드 가중치 표준화.
💭: 스케일드 가중치 표준화는 다음의 확장입니다 가중치 표준화. 본질적으로 가중치 표준화는 합성곱 층의 가중치를 정규화하여, 가중치의 평균을 0, 분산을 1이 되게 만듭니다. 이는 은닉 활성값을 표준화하는 Batch Normalization(이하 BatchNorm)과는 다르다는 점에 유의하세요. 더 궁금하신 분들을 위해, 여기 Yannic Kilcher가 제작한 훌륭한 영상으로, 가중치 표준화를 설명합니다.
스케일드 가중치 표준화 다음과 같이 정식화되었습니다:
Wi,j^=γWi,jμwiσwi.N\hat{W_{i,j}} = γ \frac{W_{i,j}-μ_{w_i}}{σ_{w_i}.\sqrt{N}}
여기서 평균 μ와 분산 σ는 합성곱 필터의 팬인 범위 전반에 걸쳐 계산됩니다. 저자들은 기본 매개변수를 다음과 같이 초기화합니다 WW가우시안 가중치에서, γ는 고정 상수입니다.
생각: 혼란스럽게 느껴지더라도 걱정하지 마세요. 아래에 공유한 코드 구현을 보면 훨씬 명확해집니다. 핵심은 모든 CNN의 가중치에 대해, 가중치의 평균을 빼고 표준편차로 나누어 표준화한다는 것입니다.
생각: 스케일드 가중치 표준화와 가중치 표준화의 ���일한 차이는 비선형성에 의존하는 고정 상수 γ를 도입했다는 점입니다. 더 궁금하신 분들은 4.2절을 참조하세요 논문 그 이유를 설명합니다.
저자들이 Normalizer-Free ResNet에 스케일드 가중치 표준화를 적용했을 때, 다음에 보이는 것처럼 그림 4, 스케일드 웨이트 스탠더다이제이션 초기화 시 채널 평균 제곱값의 증가를 제거합니다. 실제로 Scaled Weight Standardization을 적용한 네트워크의 SPP는 BatchNorm을 사용하는 네트워크의 SPP와 거의 동일합니다 ReLU–BN–Conv 정렬 순서(빨간색으로 표시)
💭: 저는 PyTorch로 Weight Standardization 구현을 가지고 있습니다 여기.
❓ 무엇인가요? 감마인가요?
γ는 비선형 함수에 따라 달라지는 고정 상수이며, 여러 비선형 함수에 대해 다음과 같은 값을 갖습니다. 비선형 함수 g에 의존하는 γ의 값은 해당 층이 분산을 보존하도록 선택됩니다.
# from https://github.com/deepmind/deepmind-research/tree/master/nfnets
_nonlin_gamma = dict(
identity=1.0,
celu=1.270926833152771,
elu=1.2716004848480225,
gelu=1.7015043497085571,
leaky_relu=1.70590341091156,
log_sigmoid=1.9193484783172607,
log_softmax=1.0002083778381348,
relu=1.7139588594436646,
relu6=1.7131484746932983,
selu=1.0008515119552612,
sigmoid=4.803835391998291,
silu=1.7881293296813965,
softsign=2.338853120803833,
softplus=1.9203323125839233,
tanh=1.5939117670059204,
)
이게 전부입니다! 스케일드 웨이트 스탠더다이제이션 정규화가 없는 네트워크가 정규화된 네트워크와 견줄 만한 성능을 내기 위해 마지막으로 맞춰 넣어야 했던 퍼즐 조각이었습니다!

PyTorch에서의 스케일드 웨이트 스탠더다이제이션

다음 구현을 참고합니다 timm 다시 한 번, 아래와 같이 스케일드 웨이트 스탠더다이제이션에 대해:
class ScaledStdConv2d(nn.Conv2d):
"""Conv2d layer with Scaled Weight Standardization.

Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
https://arxiv.org/abs/2101.08692
"""
def __init__(
self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
bias=True, gamma=1.0, eps=1e-5, gain_init=1.0):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=bias)
self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
self.eps = eps

def get_weight(self):
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = self.scale * (self.weight - mean) / (std + self.eps)
return self.gain * weight
다시 한 번, 위 구현을 공식과 맞춰보는 과제는 여러분께 맡기겠습니다. 물론 혼란스러운 부분이 있다면 언제든지 저에게 문의해 주세요.

결론

이 보고서를 통해 독자 여러분께 NF-ResNet을 설명하고, 이 네트워크를 코드로 어떻게 구현할 수 있는지도 보여드릴 수 있었기를 바랍니다. 지난 몇 주 동안 셀 수 없을 만큼 많은 시간을 들여 NF-ResNet을 직접 읽고 이해한 뒤, 그 내용을 최대한 간결하고 읽기 쉬운 보고서로 정제하려고 노력했습니다. 최종본에 만족할 때까지 여러 차례 처음부터 다시 쓰며 수정을 거듭했습니다. 제가 이 보고서에 쏟은 노력이 사랑하는 독자 여러분께 실제로 도움이 되기를 진심으로 바랍니다.
이 보고서가 다소 수학적인 내용이 많다는 점은 알고 있습니다. 다만, 식을 부분별로 나누어 설명하면서 최대한 쉽게 풀어쓰려고 노력했습니다. Normalizer-free 네트워크를 가장 잘 설명하려면 수학적 내용이 꼭 필요했기 때문입니다.
제 글에서 늘 그렇듯 블로그 게시물, 언제든지 연락 주세요 트위터 또는 제가 놓친 부분이 있다면 아래 댓글로 건설적인 피드백을 남겨 주세요.
읽어 주셔서 감사합니다. 즐거운 실험 되세요!

이 글은 AI로 번역된 기사입니다. 오역이 있을 수 있으니 댓글로 알려 주세요. 원문 보고서는 다음 링크에서 확인하실 수 있습니다: 원문 보고서 보기