Skip to main content

Pytorchでのレイヤー正規化 (例付き)

コードとインタラクティブなパネルを備えた、Pytorchでのレイヤー正規化の簡単でわかりやすい入門書です。
Created on March 14|Last edited on May 30
機械学習アルゴリズムのトレーニングは、特に現実世界のデータセットの場合、困難な作業になる可能性があります。陥りやすい落とし穴は数多くありますが、その中でも中間アクティベーションの統計的安定化が上位に挙げられることがよくあります。
このレポートでは、統計的安定化に使用される一般的な方法の1つについて簡単に説明します。レイヤー正規化。このレポートは、バッチ正規化から始まった機械学習の正規化に関するシリーズの続編です。年末までに最後の2つを公開したいと考えています。

Table of Contents



そもそもレイヤー正規化とは何ですか?

問題

すでにご存知の通り、機械学習モデルのトレーニングは確率的 (ランダム) プロセスです。これは、初期化、さらには最も一般的なオプティマイザー (SGD、Adamなど) が本質的に確率的であるという事実に由来しています。
このため、MLの最適化には、ソリューションの風景の中で鋭い(一般化できない)最小値に収束し、大きな勾配が発生するリスクがある傾向があります。簡単に言えば、アクティベーション(別名、非線形層からの出力) は、大きな値に達する傾向があります。これは、控えめに言っても理想的ではありません。これを修正する最も一般的な方法は、バッチ正規化を使用することです。
ただし、ここには落とし穴があります。バッチ正規化は、バッチ数が減るとすぐに失敗します。最新のMLアルゴリズムのデータ解像度が向上するにつれて、これは大きな問題になります。データをメモリに収めるためにバッチサイズを小さくする必要があります。さらに、バッチ正規化を実行するには、各層でのアクティベーションの移動平均・分散を計算する必要があります。この方法は、層の統計的推定値がシーケンスの長さ(つまり、同じ隠れ層が呼び出される回数)に依存する反復モデル(RNNなど)には適用できません。

ソリューション

レイヤー正規化は、アクティベーションのバッチ内の各項目の統計(つまり、平均と分散)を計算し、これらの統計的推定値で各項目を正規化することにより、これら両方の問題に対する簡単な解決策を提供します。
具体的には、形状のサンプルを考えると[N,C,H,W][N, C, H, W] LayerNormは、各バッチの形状のすべての要素の平均と分散を計算します。[C,H,W][C,H,W] (以下の図を参照)。この方法は、上記の両方の問題を解決するだけでなく、推論のために平均と分散を保存する要件 (バッチ正規化層がトレーニング中に行う必要があるもの) も削除します。


コードを見てみましょう

PyTorchでのレイヤー正規化の実装は比較的単純なタスクです。これを行うために、 torch.nn.LayerNorm()を使用できます。
ただし、畳み込みニューラル ネットワークの場合は、畳み込みの実行中に使用されるパラメーターを考慮して、出力アクティベーションマップの形状を計算する必要もあります。簡単な実装は、以下のcalc_activation_shape()機能で提供されます。(プロジェクト内でご自由に再利用してください)。
class Network(torch.nn.Module):
@staticmethod
def calc_activation_shape(
dim, ksize, dilation=(1, 1), stride=(1, 1), padding=(0, 0)
):
def shape_each_dim(i):
odim_i = dim[i] + 2 * padding[i] - dilation[i] * (ksize[i] - 1) - 1
return (odim_i / stride[i]) + 1

return shape_each_dim(0), shape_each_dim(1)

def __init__(self, idim, num_classes=10):
self.layer1 = torch.nn.Conv2D(3, 5, 3)
ln_shape = Network.calc_activation_shape(idim, 3) # <--- 畳み込みの出力の形状を計算します
self.norm1 = torch.nn.LayerNorm([5, *ln_shape]) # <--- C、H、Wでのアクティベーションを正規化します(上の図を参照)
self.layer2 = torch.nn.Conv2D(5, 10, 3)
ln_shape = Network.calc_activation_shape(ln_shape, 3)
self.norm2 = torch.nn.LayerNorm([10, *ln_shape])
self.layer3 = torch.nn.Dense(num_classes)

def __call__(self, inputs):
x = F.relu(self.norm1(self.layer1(input)))
x = F.relu(self.norm2(self.layer2(x)))
x = F.sigmoid(self.layer3(x))
return x
次のグラフに示すように、レイヤー正規化を使用した場合と使用しない場合で、Colabノートブックで提供されているモデルのベンチマークを行います。レイヤー正規化はここで非常にうまく機能します。(注: 平均 4回の実行を計算します。実線は、これらの実行の平均結果を示します。明るいほうの色は標準偏差を示します。)

Run set
8


結論

レイヤー正規化についての簡単な紹介をご覧いただき、ありがとうございました。今後数週間にわたって、その他の一般的な正規化手法をいくつか紹介し、最後に、いつどれを使用するかについてのいくつかのヒントやテクニックを含むメタレポートでシリーズを終了します (もちろん、ここでは文脈は非常に重要です)。
取り上げてほしいその他の基本的なテクニックについてのリクエストがある場合は、以下のコメントを投稿してください。それではまた次回!
Iterate on AI agents and models faster. Try Weights & Biases today.