해설: 정규화 없이 작동하는 ResNet의 성능 격차를 해소하기 위한 신호 전파 특성화
BatchNorm은 거의 모든 SOTA 이미지 분류기에서 핵심 구성 요소이지만, 실무적으로 여러 도전 과제도 함께 가져옵니다. 그렇다면 BatchNorm 없이도 비슷한 성능을 낼 수 있는 방법이 있을까요? 답은 “그렇다!”인 듯합니다. 이 글은 AI 번역 기사입니다. 오역이 있을 수 있으니 댓글로 알려주세요.
Created on September 15|Last edited on September 15
Comment
이 보고서는 두 개의 주요 섹션으로 구성되어 있습니다.
- 소개 논문의 핵심 아이디어와 기여에 대한 고수준 소개
이 보고서는 연구 논문을 단순하고 소화하기 쉬운 구성으로 풀어 설명하며, 톱다운 접근을 따릅니다. 먼저 높은 수준에서 전반적인 개념을 이해한 뒤, 논문의 세부 사항을 살펴봅니다. 또한 초보자의 관점에서 논문을 설명하는 것을 목표로 하기에, 일부 섹션에서는 원 논문보다 다소 길어질 수 있습니다.
이 보고서의 대부분 섹션에서는 연구 논문의 문장을 직접 인용한 뒤, 필자의 이해를 바탕으로 핵심 아이디어를 더 쉬운 언어로 풀어 설명합니다.
💭: 이 보고서 곳곳에는 이런 사이드바가 등장합니다. 이는 독자에게 도움이 될 만한 필자의 개인적인 코멘트입니다. 때로는 재미있는 요소나 실패한 실험담도 함께 담겨 있을 수 있습니다.
💭: 이 보고서는 길고, 일반적인 블로그 글보다 더 길지만, 내용 면에서는 완성형입니다. 어떤 개념도 빠뜨리지 않으려고 의도적으로 신경 썼습니다. 개인적인 의견으로는, 이 보고서는 나눠서 읽는 것이 가장 좋으며, 처음 접하는 독자에게는 재독이 필요할 수 있습니다 NF-ResNet 처음으로
💭: 이 논문은 기존 연구를 토대로 하고 있으므로, 여기에 언급된 다양한 개념에 대한 사전 지식이 어느 정도 필요합니다 사전 지식 이번 섹션이 큰 도움이 될 것입니다. 이 보고서만으로는 이 연구 논문을 충분히 설명하지 못했다고 느끼신다면, 과거 연구를 함께 살펴보며 빈틈을 메우는 논문 리딩 세션을 열어도 좋습니다. 원하시면 말씀해 주세요. 바로 준비하겠습니다 :)
💭: 제 개인적인 목표는 이 연구 논문을 누구나 쉽게 읽고 이해할 수 있는 형식으로 압축한 보고서를 만드는 것입니다. 만약 보고서의 일부가 이해되지 않거나 혼란스럽다면, 보고서 마지막에 건설적인 피드백을 자유롭게 남겨 주세요.
사전 지식
이 보고서를 가장 효과적으로 활용하려면, 독자는 다음에 대해 전반적인 이해를 갖추고 있는 것이 좋습니다 ResNet, 배치 정규화, ReLU 활성화 & 웨이트 스탠더드라이제이션다음 주제들에 대한 전반적인 소개를 빠르게 익히는 데 도움이 될 만한 자료들을 소개합니다:
💭: 위에서 언급한 두 번째 자료에는 제가 좀 편향적일 수 있어요. ;)
소개
논문의 Introduction 섹션에서:
BatchNorm은 딥러닝의 핵심 계산 원시 연산으로 자리 잡았으며, 최첨단 이미지 분류기 대부분에서 사용됩니다. BatchNorm의 다양한 이점도 확인되어 왔습니다. 손실 지형을 더 매끄럽게 만들어 더 큰 학습률로 학습할 수 있게 하고, 배치 통계를 미니배치로 추정하면서 생기는 노이즈가 암묵적 규제로 작용합니다. 그러나 BatchNorm에는 단점도 많습니다. 동작이 배치 크기에 크게 의존하여, 디바이스당 배치 크기가 너무 작거나 너무 클 때 성능이 저하되며, 학습 시점과 추론 시점의 모델 동작 사이에 불일치를 야기합니다. 여러 대체 정규화 레이어가 제안되었지만, 일반적으로 이러한 대안은 일반화 성능이 낮거나 추론 시 추가 연산 비용과 같은 고유한 단점을 도입합니다. 또 다른 연구 흐름은 은닉 활성값을 정규화하는 레이어 자체를 완전히 제거하려는 시도를 하고 있습니다.
이 논문에서 저자들은 정규화 레이어 없이도 최첨단에 견줄 만한 테스트 정확도를 달성하는 심층 ResNet을 학습하기 위한 일반적인 방법론을 정립하고자 합니다! 배치 정규화(BatchNorm) 은 컴퓨터 비전 분야의 딥러닝 연구를 발전시키는 데 핵심적인 역할을 해 왔지만, 최근 몇 년 사이에는 활성값을 정규화하는 레이어를 아예 제거하려는 새로운 연구 흐름이 등장했습니다.
❓: 왜 BatchNorm을 제거하고 싶을까요? 이에 대한 답변은 이미 왜 노멀라이저 프리 네트워크가 필요할까요? BatchNorm에는 무엇이 문제일까요? 이 보고서의 섹션
이 연구 논문은 이러한 연구 흐름을 따르며 이를 핵심 기여 입니다:
- 신호 전파 플롯저자들은 딥 레지듀얼 네트워크에서 초기화 시 순전파 단계의 신호 전파를 점검할 수 있도록, 실무자가 활용하기 쉬운 간단한 시각화 도구 모음을 제안합니다.
- 스케일드 웨이트 스탠더드라이제이션: 저자들은 ReLU 또는 Swish 활성화와 가우시안 가중치를 사용하는 과거 비정규화 ResNet에서의 핵심 실패 양상을 지적합니다. 이러한 비선형 함수들의 출력 평균이 양수이기 때문에, 각 채널의 은닉 활성에 대한 제곱 평균이 네트워크 깊이가 증가할수록 급격히 커집니다. 이를 해결하기 위해 저자들은 다음을 제안합니다. 스케일드 웨이트 스탠더드라이제이션 의 확장판인 웨이트 스탠더드라이제이션본질적으로, 웨이트 스탠더드라이제이션은 합성곱 층의 가중치를 정규화하여 가중치의 평균을 0, 분산을 1로 만드는 방법입니다.
- BatchNorm과 견줄 만한 성능 대응항목저자들은 Scaled Weight Standardization과 결합한 노멀라이저 프리 네트워크 구조를 ImageNet의 ResNet에 적용하여, 최초로 최대 288층에 달하는 매우 깊은 네트워크에서도 BatchNorm을 사용한 ResNet과 견줄 만하거나 더 나은 성능을 달성했습니다.
❓: “ 무엇을 의미하나요?이러한 비선형 함수들의 출력 평균은 양수이다위의 두 번째 핵심 bullet에서 “ ”는 무엇을 의미하나요? ReLU는 결국 단지 a 입력이 주어졌을 때 수행되는 연산 그래서 ReLU 활성화의 출력은 항상 양수이므로, 평균이 다음을 향해 이동합니다 이를 완화하기 위해, 저자들은 이후에 자세히 다룰 Scaled Weight Standardization을 도입했습니다 스케일드 웨이트 스탠더드라이제이션 보고서의 섹션
💭: 지금까지는 BatchNorm이 없는 네트워크가 SOTA에 견줄 만한 성능을 내지 못했기 때문에, 왜 이것이 꽤나 흥미로운 소식인지 짐작하실 수 있을 겁니다! 실제로, 이들의 후속 논문게다가 저자들은 새로운 SOTA까지 달성합니다! 하지만, 지금은 성급하게 앞서 나가지 말도록 하죠..
왜 노멀라이저 프리 네트워크가 필요할까요? BatchNorm에는 무엇이 문제일까요?
“BatchNorm에 뭐가 문제죠? 지금까지 거의 모든 네트워크에서 봐 왔는데요…”라고 궁금하시다면, 이 섹션이 그 답을 제공해 드릴 겁니다.
요약하면, BatchNorm은 매우 뛰어난 장점들이 있지만 단점도 있습니다.
논문에서 언급된 핵심 장점들은 다음과 같습니다:
BatchNorm의 여러 가지 장점이 확인되었습니다. 손실 지형을 매끄럽게 만들어 줍니다 (산투르카르 외, 2018), 더 큰 학습률로 학습할 수 있게 해 주며 (비외르크 외, 2018)이며, 배치 통계의 미니배치 추정에서 발생하는 노이즈가 암시적 정규화를 도입합니다 (뤄 외, 2019). 또한 항등 스킵 연결을 갖는 깊은 잔차 네트워크에서 초기화 시 신호 전파를 잘 유지합니다 (디 외, 2020)."
하지만 BatchNorm에는 단점도 많습니다. 또한, 논문에서는 다음과 같이 말합니다:
이 동작은 배치 크기에 강하게 의존하며, 디바이스당 배치 크기가 너무 작거나 너무 클 때 성능이 저하됩니다 (호퍼 외., 2017), 그리고 학습 중 모델의 거동과 추론 시의 거동 사이에 불일치를 초래합니다. BatchNorm은 또한 메모리 오버헤드를 증가시킵니다 (로타 불로 외., 2018), 그리고 구현상의 오류가 자주 발생하는 원인이기도 합니다 (팜 외., 2019). 또한, 서로 다른 하드웨어에서 학습된 배치 정규화 모델을 재현하기가 어려운 경우가 많습니다.
따라서 이 연구 흐름은 다음과 같은 논리를 따른다 — “배치 정규화의 장점은 유지하면서 단점은 없앤 정규화기 없는 네트워크를 찾을 수 있다면, 더 작은 배치 크기로도 학습할 수 있고, 학습과 추론 속도를 높이며 메모리 오버헤드까지 줄일 수 있습니다!"""
또한 정규화기 없는 네트워크에서도 전체 네트워크에 걸쳐 신호 전파가 잘 이루어지길 바랍니다. 그런데 신호 전파를 측정할 방법이 있을까요? 정규화기 없는 네트워크를 BatchNorm을 사용하는 대응 모델과는 어떻게 비교할 수 있을까요? 바로 여기서 신호 전파 플롯이 등장합니다.
신호 전파 플롯
💭: 이건 이 논문에서 제가 가장 좋아하는 섹션 중 하나입니다. 기본적으로 신호 전파 플롯은 네트워크 내부의 “신호 전파”를 측정하는 데 도움을 주는 플롯입니다. 어떻게 할까요? 네트워크 내부의 서로 다른 지점에서(단일 순전파 동안) 몇 가지 통계를 계산하고 그것들을 플롯합니다.
논문에서:
최근 논문들이 ResNet에서의 신호 전파를 이론적으로 분석해 왔지만, 실제로는 실무자가 새로운 모델을 설계하거나 기존 아키텍처에 수정을 제안할 때 특정 심층 네트워크 내부의 서로 다른 깊이에서 은닉 활성값의 스케일을 경험적으로 평가하는 일은 드뭅니다. 이에 비해, 우리는 네트워크 내부의 여러 지점에서 은닉 활성값의 통계를 플로팅해 보는 것이, 무작위 가우시안 입력 배치나 실제 학습 예제 배치로 조건화했을 때 모두 매우 유익하다는 것을 확인했습니다.
저자들은 네트워크 내부의 서로 다른 지점에서 은닉 활성의 통계를 플로팅하는 것이 유용하다고 발견했고, 이러한 플롯을 다음과 같이 명명했습니다 신호 전파 플롯.
💭: 이상적으로는 네트워크 전체에서 은닉 활성의 평균이 0이고 분산이 1이길 바랍니다. 이는 “좋은” 신호 전파의 훌륭한 척도입니다.
저자들은 다음과 같이 차원을 표기한 4차원 입력 및 출력 텐서를 고려합니다 NHWC 여기서 엔 배치 차원을 나타냅니다 C 채널을 나타내며, ㅎ 그리고 더블유 두 개의 공간 차원인 높이와 너비를 나타냅니다.
저자들은 또한 다음과 같은 형태의 항등 잔차 블록을 가정합니다:
여기서, 을(를) 나타냅니다 블록, 가 계산하는 함수를 나타냅니다 잔차 분기.
😕: 수식이 혼란스럽나요? 이 수식은 다음에 보여줄 프리액티베이션 ResNet 블록을 나타냅니다 연구 논문의 그림 4(a). 여기 다음과 같이 BatchNorm, ReLU, Conv 연산으로 구성된 잔차 분기에서의 함수를 나타냅니다 그림 4(a)이는 이후 섹션에서 Residual Block이라고도 합니다.
그런 다음 신호 전파 플롯(SPP)을 생성하기 위해, 저자들은 초기화 스킴에 따라 네트워크를 초기화합니다(예를 들어 He 초기화, 또는 Glorot 초기화, 또는 그 밖의 어떤 스킴이든), 유닛 가우시안 분포에서 샘플링한 입력 예제 배치를 네트워크에 제공합니다.
💭: 간단히 말해, 원하는 네트워크를 구성하고 적절한 초기화 스킴으로 초기화한 뒤, 평균이 0이고 분산이 1인 가우시안 입력으로 한 번 전방 패스를 수행하세요.
이 입력이 네트워크를 통해 전파될 때, 저자들은 각 Residual Block의 끝에서 다음 통계��를 플로팅합니다:
- 평균 채널 제곱 평균NHW 축 전반에서의 평균을 제곱하여 계산하고, 이를 C 축에 걸쳐 평균낸 값.
- 평균 채널 분산, NHW 축 전반에서 채널 분산을 계산한 뒤, 이를 C 축에 걸쳐 평균낸 값.
- 잔차 브랜치 끝에서의 평균 채널 분산, 스킵 경로와 병합하기 전에.
💭: 지금 잠시 멈추고, 눈을 감고 SPP를 감(感)으로 한번 그려보는 건 어떨까요? 어떤 모습일 것 같나요? 무엇을 기대하시나요?
아래에서 그림 1, 우리는 …의 예시를 볼 수 있습니다 신호 전파 플롯 신호 전파를 내부에서 측정하는 ResNet-V2 600 네트워크에서 위의 세 가지 통계를 두 개의 네트워크에 대해 그립니다. 하나는… BN-ReLU-Conv 잔차 블록 내부의 계층 배치 순서와, 다른 하나는 ReLU-BN-Conv 배치 순서…
❓: BN-ReLU-Conv 계층과 ReLU-BN-Conv 배치 순서는 무엇이 다른가요? 간단히 말해 BN과 ReLU 활성화 계층의 위치를 서로 바꿉니다. 연구 논문의 그림 4(a).
🤔: 플롯이 당신의 예상과 일치하나요? 아니라면, 그 이유는 무엇인가요?

그림 1: BatchNorm, ReLU 활성화, He 초기화를 사용한 ResNetV2-600의 초기화 시 신호 전파 플롯. 입력은 평균 0, 분산 1의 가우시안이며 해상도는 512px이다. 검은 점은 스테이지의 끝을 나타낸다. 파란 플롯은 BN-ReLU-Conv 배치 순서를, 빨간 플롯은 ReLU-BN-Conv 배치 순서를 사용한다.
만약 당신이 기억한다면 ResNet 논문각 ResNet 모델은 네 개의 스테이지로 나뉘며, 스테이지마다 블록 수가 서로 다릅니다. ResNet-V2의 경우 총 200개의 잔차 블록으로 구성되며, 각 스테이지는 50개의 블록을 포함합니다.
💭: 이 모델은 저자들이 제공한 코드에서 정의되었습니다 여기.
ResNet-V2 600에는 잔차 블록이 200개 있으므로, 그 결과가 그림 1로 나타납니다. 각 잔차 블록마다 세 가지 통계값을 계산하여 이를 플로팅함으로써.
💭: 그림 1은 꽤 오랫동안 저를 혼란스럽게 했습니다. 저는 ResNet-V2 600에 잔차 블록이 600개 있다고 생각해서, X축의 최댓값이 200인 것을 보고 혼란스러웠죠. 그런데 코드를 확인하고 나니 이제 이해가 됩니다. 여러분에게는 저만큼 헷갈리지 않길 바랍니다.
핵심 관찰사항
⚠️: 여기서는 특히 주의해서 보세요.
그림 1에서 확인할 수 있는 핵심 패턴 — 노말라이저 프리 신경망을 살펴볼 때 큰 도움이 됩니다:
- 그림 1(b)에서 보면, 평균 채널 분산은 깊이에 따라 선형적으로 증가한다 주어진 스테이지에서는 그렇게 되고, 각 전이 블록에서 초기화됩니다.
- 그림 1(a)에서, BN-ReLU-Conv 순서의 경우, 평균 채널 제곱 평균도 깊이에 따라 선형적으로 증가하는 유사한 거동을 보인다.
💭: 호기심 많은 독자분들께는 다음을 참고하시길 권합니다 SkipInit 연구 논문 같은 저자들이 쓴 논문으로, 초기화 시 잔차 분기에서 신호를 축소하기 위해 BatchNorm을 스칼라로 대체한 작업입니다.
💭: 만약 이러한 패턴을 Normalizer‑Free ResNet에서도 모방할 수 있다면, 정규화를 쓰지 않는 새로운 네트워크 역시 잘 학습되고, 정규화된 대응 모델들과 경쟁력 있게 만들 수 있습니다.
PyTorch에서 사용자 정의 모델의 신호 전파 플롯 그리기
좋아요, 그러면 이걸로—이게 끝입니다. 신호 전파 플롯 이론적으로는 그렇습니다. 우리가 직접 이런 플롯을 만들어 재현해 볼 수 있다면 재미있지 않을까요? 사실, 제가 가장 좋아하는 라이브러리 중 하나를 사용하면 가능합니다 — timm — 바로 이것을 하려면!
import torchvisionfrom timm.utils.model import extract_spp_stats, avg_ch_var, avg_ch_var_residual, avg_sq_ch_meanmodel = torchvision.models.resnet50()spp_stats = extract_spp_stats(m,hook_fn_locs=['layer?.?', 'layer?.?', 'layer?.?.bn3'],hook_fns=[avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual])# plot statsfig, ax = plt.subplots(1, 3, figsize=(18,3), sharey=True)ax[0].plot(stats['avg_sq_ch_mean'], label='avg_sq_ch_mean');ax[0].legend(); ax[0].grid();ax[1].plot(stats['avg_ch_var'], label='avg_ch_var');ax[1].legend(); ax[1].grid();ax[2].plot(stats['avg_ch_var_residual'], label='avg_ch_var_residual');ax[2].legend(); ax[2].grid();

💭: 저는 운이 좋게도 SPP를 여기에 기여할 수 있었습니다 timm. ResNet V2 600의 그림 1을 재현할 수 있는 전체 노트북은 다음을 참조하세요 여기.
정규화기 없는 ResNet
💭: 아래 섹션은 위의 내용들보다 조금 더 복잡하게 느껴질 수 있지만, 가장 중요한 부분이기도 합니다. 바로 여기에서 Normalizer‑Free ResNet이 소개되기 때문입니다. 궁금한 점이 있으면 이 섹션을 다시 읽거나 언제든지 저에게 문의하세요.
이제 BatchNorm으로 정규화된 네트워크를 충분히 잘 이해했고, 분석을 도와줄 SPP도 갖췄으니, 정규화는 없지만 신호 전파가 좋고 학습 동안 안정적이며 BatchNorm을 사용한 모델에 필적하는 테스트 정확도에 도달하는 ResNet 변형들을 살펴볼 준비가 되었습니다.
💭: 사실 우리가 필요했던 건 BatchNorm으로 정규화된 네트워크와 SPP에 대한 충분한 이해였습니다. 그것만 있으면 정규화기 없는 네트워크를 개발할 수 있죠. 핵심은, 저자들이 정규화된 ResNet의 SPP를 흉내 내는 정규화기 없는 네트워크를 만든다는 것입니다. 왜일까요? 이에 대한 답은 이미 다음에서 설명되었습니다 Q&A 섹션 아래
- BatchNorm은 잔차 블록의 입력을 다운스케일합니다 입력 신호의 표준편차에 비례하는 계수로.
- 각 잔차 블록은 신호의 분산을 증가시킵니다 대략 일정한 계수만큼. (평균 채널 분산은 선형적으로 증가합니다)
저자들은 다음과 같은 형태의 새로운 네트워크를 설계하여 이러한 효과를 모방할 것을 제안했습니다:
여기서, 을(를) 나타냅니다 잔차 분기와 을(를) 나타냅니다 잔차 분기.
💭: 이는 다음과 같은 잔차 블록과는 대조적이라는 점에 유의하세요 where
저자들은 다음과 같은 방식으로 새로운 노말라이저 프리 네트워크를 설계했습니다:
⚠️: 여기서는 특히 주의하세요.
- , 잔차 분기가 계산하는 함수는 초기화 시 분산 보존이 되도록 매개변수화됩니다. 즉, .
- 은 다음과 같이 선택된 스칼라 함수입니다 . 이는 다음을 보장합니다 단위 분산을 갖습니다.
- 블록 간 분산 증가율을 제어하는 스칼라 하이퍼파라미터입니다.
💭: 네, 정보는 많은데 설명은 훨씬 적다는 걸 알고 있어요. 설명 다음은 문답 형식으로 이어집니다.
문답 안내 정규화기 없는 ResNet
💭: 이 절에서는 제가 처음 이 논문을 읽었을 때 매우 혼란스러웠던 질문들에 답해 보겠습니다.
❓ …라고 말할 때 무슨 뜻인가요? 분산 보존이라는 뜻인가요?
기본적으로 이것은 …이라는 뜻입니다 입력의 분산을 바꾸지 않습니다. 즉,
❓ 왜 우리는 원하는가 분산을 보존한다는 것인가요?
논문에서 발췌:
이 제약 조건 덕분에 네트워크에서 신호가 어떻게 증가하는지 논리적으로 파악할 수 있고, 분산을 해석적으로 추정할 수 있습니다.
그리고 분산을 해석적으로 계산할 수 있으므로, 이는 또한 다음의 값들도 계산할 수 있다는 뜻입니다 해석적으로
❓ 왜 정규화가 없는 네트워크가 ResNet의 SPP 추세를 따라야 하나요?
정규화되지 않은 네트워크에서 신호가 어떻게 전파되어야 하는지는 대체로 설계의 문제입니다. 부록 G.2에서 우리는 처음에 분산을 일정하게 유지하도록 네트워크를 설계하는 방향을 탐색했음을 언급합니다. 사전 지식이 없다면 이것이 더 나은 선택이라고 가정할 수 있습니다. 그러나 우리는 이러한 네트워크가 성능이 떨어진다는 것을 발견했고, 실험에서 잘 작동함이 알려진 신호 전파 템플릿을 모방하는 것이 바람직한 설계 선택이라고 판단했습니다. 이는 우리의 실험 결과로도 뒷받침됩니다.
❓ 왜인가요? 으로 선택된 ?
입력 신호의 표준편차입니다 에게 잔차 블록. 신호를 표준편차로 나누면, 우리는 다음을 확실히 할 수 있습니다. 단위 분산을 가지므로, 안정적인 학습에 바람직합니다!
❓ 무엇인가요? ?
새로운 노말라이저 프리 블록의 설계에서 알 수 있듯이,
그러므로, 이다 로 선택됨 , 그러므로, 단위 분산을 가진다. 그러므로, 분산을 보존하므로, 또한 단위 분산을 가진다.
이제 분산을 계산하면, 다음을 얻는다:
그러므로,
💭: 여기까지 오고 이해하는 데 2주가 걸렸다 , 처음에 이해하지 못해도 괜찮습니다. 이 부분이 이해되지 않으면 알려 주세요! NF-ResNet을 온전히 이해하는 데는 중요하지만, 구현 관점에서는 그리 중요하지는 않습니다.
💭: 이는 논문에 대한 제 이해를 바탕으로 한 것으로, 다른 어디에서도 명시적으로 설명되지 않았습니다.
자, 우리가 정규화 없이도 재현하고 싶은, BatchNorm이 수행하는 두 가지는 다음과 같습니다:
- 잔차 분기의 입력을 다운스케일하기
- 각 잔차 블록마다 신호의 분산을 대략 일정한 비율로 증가시킨다
그러므로, 따라서 α는 블록 간 분산 성장률을 제어하는 스칼라 하이퍼파라미터입니다. 즉, 각 잔차 블록이 신호의 분산을 대략 일정한 비율로 증가시키도록 보장합니다.
또한 잔차 분기의 입력은 이는 …의 비율로 다운스케일된 입력입니다 .
💭: 위에서 노말라이저 프리 네트워크의 원리와 배경을 최대한 쉽게 설명해 보았습니다. 이 부분 중 이해가 잘 되지 않는 부분이 있었다면, 보고서 끝부분에서 언제든지 알려 주세요. :)
PyTorch/Jax에서의 NF-ResNet
로스 와이트먼 는 제가 좋아하는 라이브러리 중 하나에 이미 NF-ResNet을 구현해 주는 훌륭한 작업을 해두었습니다 — timm! 이제 NF-ResNet을 만드는 일은 다음과 같이 간단합니다:
import timmimport torchm = timm.create_model('nf_resnet50')x = torch.randn(1, 3, 224, 224)m(x).shape>> torch.Size([1, 1000])
위에서는 다음을 사용해 간단히 Nf-ResNet 50 모델을 생성합니다 timm 그리고 임의의 입력을 전달해 분류 결과를 얻으면 됩니다. 이 네트워크를 사용자 정의 데이터셋에 적용해 파인튜닝할 수도 있습니다. 동일한 학습 스크립트와 파인튜닝 절차를 그대로 따라 하면 됩니다. 여기 timm 문서에서
💭: 저는 정말 좋아합니다 timm가장 빠르게 성장하는 라이브러리 중 하나이며, 다음에 의해 최신 상태로 유지됩니다 로스. 최신 연구 논문은 곧바로 timm 정말 정말 빠르게! 저는 또한 운 좋게도 …에 참여할 수 있었습니다 timm 문서 프로젝트의 더 자세한 문서는 timm.
💭: PyTorch에서 Nf-ResNet을 실험해 보고 싶다면, 사용해 보세요 timm 이러한 네트워크를 시작하는 가장 쉬운 방법 중 하나입니다. 이 … 또한 훌륭한 자료입니다 (작성자: …) 아유시 타쿠르 PyTorch Lightning을 사용하고 timm 배치 크기를 빠르게 실험해 보기 위해.
코드에서 Nf-ResNet을 빠르고 쉽게 시작하는 방법을 살펴보았으니, 이제 소스 코드를 살펴보겠습니다. 이러한 네트워크의 timm 이러한 네트워크를 처음부터 어떻게 구축할 수 있는지 이해하기 위해.
💭: 우리는 다음을 살펴볼 수도 있었지만 공식 코드 구현 DeepMind의 JAX 구현도 있지만, 단순하게 가기 위해 여기서는 PyTorch로 통일하겠습니다. 우리는 다음의 소스 코드를 사용할 것입니다 timm.
이제 Normalizer Free 블록을 만들기 위해, 우리는 다음의 식을 다시 구성해야 합니다:
아래에서 PyTorch로 해봅시다:
import torchimport torch.nn as nnfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STDfrom timm.models.helpers import build_model_with_cfgfrom timm.models.layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\get_act_layer, get_act_fn, get_attn, make_divisibleclass NormFreeBlock(nn.Module):"""Normalization-Free pre-activation block."""def __init__(self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):super().__init__()out_chs = out_chs or in_chsmid_chs = int(in_chs * bottle_ratio)self.alpha = alphaself.beta = betaif in_chs != out_chs or stride != 1 or dilation != first_dilation:self.downsample = DownsampleAvg(in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer)else:self.downsample = Noneself.act1 = act_layer()self.conv1 = conv_layer(in_chs, mid_chs, 1)self.act2 = act_layer(inplace=True)self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)self.act3 = act_layer()self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0.)def forward(self, x):out = self.act1(x) * self.beta# shortcut branchshortcut = xif self.downsample is not None:shortcut = self.downsample(out)# residual branchout = self.conv1(out)out = self.conv2(self.act2(out))out = self.conv3(self.act3(out))out = out * self.alpha + shortcutreturn out
위의 코드는 아래 그림에 나온 것처럼 Pre-Activation ResNet 블록과 유사한 Normalizer Free Bottleneck 블록을 재구현한 것입니다. 하지만 BatchNorm 없이:

💭: 여기서는 독자에게 연습문제로 남겨 두겠습니다. 그림 3 - (a) & (b) 의 코드 구현으로부터 timm 위에서 공유한 Normalizer Free 블록들에 대해. 이 코드 스니펫은 프리-액티베이션 ResNet 블록과 프리-액티베이션 ResNet 전환 블록을 모두 구현할 수 있습니다.
💭: 또한 Nf-ResNet을 만들려면, 앞서 언급한 다양한 ResNet 구성에 맞춰 블록을 반복해서 쌓으면 됩니다. 여기. 이것도 다시 한 번 독자 연습 문제로 남겨 두지만, 필요하다면 더 깊이 파고들 준비가 되어 있습니다. 이렇게 연습 문제로 둔 이유는, 해당 저장소의 소스 코드를 사용해 네트워크를 직접 재구현해 보면서 Nf-ResNet을 이해하는 데 정말 좋은 실습/프로젝트가 되기 때문입니다. timm 가이드로서
💭: 전체 구현을 이해하는 데 필요한 모든 구성 요소는 이미 이 보고서에서 모두 공유했습니다.
스케일드 웨이트 스탠더드라이제이션
💭: 아직 풀리지 않은 수수께끼가 하나 남아 있지만, 이게 마지막입니다.

그림 4: ReLU 활성화를 사용하는 ResNetV2-600 네트워크의 세 가지 변형에 대한 SPP. 빨간색은 ReLU-BN-Conv 순서의 배치 정규화된 네트워크. 초록색은 He 초기화와 α = 1을 사용한 Normalizer-Free 네트워크. 청록색은 같은 Normalizer-Free 네트워크에 Scaled Weight Standardization을 적용한 경우.
저자들은 앞서 소개한 NF-ResNet 네트워크를 구현하고 이를 비교했다 신호 전파 플롯 ReLU-BN-Conv 순서를 사용하는 정규화된 네트워크와 함께(이미 앞에서 본) 그림 1). NF-ResNet은 다음과 같이 초기화했다 He 초기화 그리고 α = 1. 두 SPP는 다음에 표시된다 그림 4 위
보시다시피 그림 4두 SPP는 동일하지 않다. 사실, 다음과 같은 두 가지 예상치 못한 패턴을 관찰할 수 있다.
- NF-ResNet에서는 깊이가 깊어질수록 채널 평균의 제곱값의 평균이 급격히 증가하여, 평균 채널 분산을 초과하는 큰 값에 도달한다. (그림 5(a)와 그림 5(b) 비교)
- NF-ResNet의 경우, 잔차 브랜치에서 관측되는 분산의 스케일은 일관되게 1보다 작다. (그림 5(c)의 녹색)
💭: 앞서 언급했듯이, 네트워크 전반에서 활성값의 평균이 0이고 분산이 1인 상태가 좋은 신호 전파로 간주된다. 따라서 깊이에 따라 채널 평균의 제곱값의 평균이 빠르게 증가하고, 분산 값이 일관되게 1보다 작게 유지되는 현상은 네트워크의 불안정성을 의미한다.
평균 이동의 발생을 방지하고 잔차 브랜치가 … 분산을 보존하도록, 저자들은 … 스케일드 웨이트 스탠더드라이제이션.
💭: 스케일드 웨이트 스탠더드라이제이션은 …의 확장이다 웨이트 스탠더드라이제이션. 본질적으로 웨이트 스탠더드라이제이션은 합성곱 층의 가중치를 정규화하여, 가중치의 평균을 0, 분산을 1로 맞추는 방법이다. 이는 은닉 활성값을 표준화하는 배치 정규화(BatchNorm)와는 다르다는 점에 유의하라. 더 궁금한 독자를 위해, 여기 Yannic Kilcher가 웨이트 스탠더드라이제이션을 설명한 멋진 동영상이다.
스케일드 웨이트 스탠더드라이제이션 다음과 같이 정식화되었다:
여기서 평균 μ와 분산 σ는 합성곱 필터의 팬인 범위 전반에 걸쳐 계산된다. 저자들은 기반 파라미터를 초기화한다 가우시안 분포에서 샘플링한 가중치로부터, γ는 고정 상수이다.
💭: 혼란스러워도 걱정하지 마세요. 아래에 공유한 코드 구현을 보면 훨씬 명확해집니다. 본질적으로는 모든 CNN의 가중치에 대해, 가중치의 평균을 빼고 표준편차로 나누어 표준화합니다.
💭: 스케일드 웨이트 스탠더드라이제이션과 웨이트 스탠더드라이제이션의 유일한 차이는 비선형 함수에 따라 달라지는 고정 상수 γ를 도입한다는 점이다. 더 궁금한 독자는 4.2절을 참고하라 논문 그 이유��� 설명한다.
저자들이 Scaled Weight Standardization을 Normalizer-Free ResNet에 적용했을 때, 다음에 보인 것처럼 그림 4, 스케일드 웨이트 스탠더드라이제이션 초기화 시 채널별 제곱 평균의 증가를 제거한다. 실제로 Scaled Weight Standardization을 적용한 네트워크의 SPP는 다음을 사용하는 배치 정규화된 네트워크의 SPP와 거의 동일하다. ReLU-BN-Conv 순서(빨간색으로 표시됨).
💭: PyTorch로 구현한 Weight Standardization이 있습니다 여기.
❓ 무엇인가요? 감마?
γ는 비선형 함수에 따라 달라지는 고정 상수이며, 여러 비선형 함수에 대해 다음과 같은 값을 갖습니다. 비선형 함수 g에 의존하는 γ의 값은 해당 층이 분산을 보존하도록 선택됩니다.
# from https://github.com/deepmind/deepmind-research/tree/master/nfnets_nonlin_gamma = dict(identity=1.0,celu=1.270926833152771,elu=1.2716004848480225,gelu=1.7015043497085571,leaky_relu=1.70590341091156,log_sigmoid=1.9193484783172607,log_softmax=1.0002083778381348,relu=1.7139588594436646,relu6=1.7131484746932983,selu=1.0008515119552612,sigmoid=4.803835391998291,silu=1.7881293296813965,softsign=2.338853120803833,softplus=1.9203323125839233,tanh=1.5939117670059204,)
이걸로 끝입니다! 스케일드 웨이트 스탠더드라이제이션 정규화기를 쓰지 않는 네트워크가 정규화된 네트워크에 견줄 만한 성능을 내기 위해 마지막으로 맞춰 넣어야 했던 퍼즐 조각이었다!
PyTorch에서의 스케일드 웨이트 스탠더드라이제이션
class ScaledStdConv2d(nn.Conv2d):"""Conv2d layer with Scaled Weight Standardization.Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -https://arxiv.org/abs/2101.08692"""def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,bias=True, gamma=1.0, eps=1e-5, gain_init=1.0):if padding is None:padding = get_padding(kernel_size, stride, dilation)super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,groups=groups, bias=bias)self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)self.eps = epsdef get_weight(self):std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)weight = self.scale * (self.weight - mean) / (std + self.eps)return self.gain * weight
💭: 다시 한 번, 위 구현을 수식과 직접 맞춰 보는 연습으로 남겨두겠습니다. 물론, 혼란스러운 부분이 있으면 언제든지 편하게 연락 주세요.
결론
이 보고서를 통해 독자분들께 NF-ResNet을 설명하고, 네트워크를 코드로 어떻게 구현할 수 있는지도 보여드릴 수 있었기를 바랍니다. 지난 몇 주 동안 셀 수 없이 많은 시간을 들여 NF-ResNet을 직접 읽고 이해한 뒤, 그 지식을 간결하고 읽기 쉬운 보고서로 정제하기 위해 최선을 다했습니다. 이 보고서는 제가 만족할 만한 최종본이 나올 때까지 여러 차례, 완전히 처음부터 다시 쓰며 반복적으로 다듬었습니다. 이 보고서에 쏟은 제 노력이 사랑하는 독자분들께 실제로 도움이 되기를 바랍니다.
이 보고서가 다소 수학적인 내용이 많다는 점은 잘 알고 있습니다. 다만, 수식을 부분별로 나누어 설명하며 최대한 이해하기 쉽게 풀어내려고 노력했습니다. 정규화기 없이 동작하는 네트워크를 가장 잘 설명하기 위해서는 수학적 설명이 꼭 필요했습니다.
읽어 주셔서 감사합니다. 즐거운 실험 되세요!
Add a comment