Skip to main content

クロスエントロピー損失:概要

PyTorchとTensorflowのコードとインタラクティブな視覚化を備えた、クロスエントロピー損失をカバーするチュートリアル
Created on January 11|Last edited on January 11

セクション





👋 クロスエントロピー損失の概要

ニューラルネットワークのトレーニングに使用される最も一般的な損失関数の1つは、クロスエントロピーです。 このブログ投稿では、さまざまなフレームワークを使用してその派生と実装を確認し、wandbを使用してそれらをログに記録して視覚化する方法を学習します。



🧑🏻‍🏫 クロスエントロピー損失とは何ですか?

クロスエントロピー損失は、機械学習の分類モデルのパフォーマンスを測定するために使用されるメトリックです。 損失(またはエラー)は0から1までの数値として測定され、0はパーフェクトモデルです。 モデルを可能な限りクロスエントロピー損失を0に近づけるのは目標です。
クロスエントロピー損失は、ロジスティック損失(またはログ損失、バイナリクロスエントロピー損失と呼ばれることもあります)と互換性があると見なされることがよくありますが、これが常に正しいとは限りません。
クロスエントロピー損失は、機械学習分類モデルが発見された確率分布と予測された分布の差を測定します。 予測に使用できるすべての値が保存されるため、たとえば、コイントスでオッズを探している場合は、その情報が0.5と0.5(表面と裏面)のように保存されます。
一方、バイナリクロスエントロピー損失は、1つの値のみを格納します。 つまり、0.5のみを格納し、他の0.5は別の問題で想定され、最初の確率が0.7の場合、他の確率は0.3であると想定されます。 また、対数を使用します(したがって「ログ損失」)。
これが、バイナリクロスエントロピー損失(または対数損失)が2つの結果しかない場合で使用される理由です。3つ以上ある場合、すぐに失敗することを簡単に理解できます。 3つ以上の分類の可能性があるモデルでの場合にクロスエントロピー損失がよく使用されます。

原理

それでは基本から始めましょう。 深層学習では、通常、勾配ベースの最適化戦略を使用して、損失関数l(f(xi),yi)l \, (f(x_i), \, y_i)を使用してモデル(f(x)f(x))をトレーンします。
ここで (xi,yi)(x_i, y_i) は入出力ペアです。 損失関数は、モデルがどれほど間違っているかを判断し、その「間違っていること」に基づいてそれ自体を改善するのに役立ちます。 これはエラーの尺度です。 トレーニング全体の目標は、このエラー/損失を最小限に抑えることです。
は入出力ペアです。 損失関数は、モデルがどれほど間違っているかを判断し、その「間違っていること」に基づいてそれ自体を改善するのに役立ちます。 これはエラーの尺度です。 トレーニング全体の目標は、このエラー/損失を最小限に抑えることです。
最尤推定 と呼ばれる学習パラダイムがあります。これは、基礎となるデータ分布を学習するために、モデルをトレーンしてパラメーターを推定します。 したがって、損失関数を使用して、モデルがデータ分布にどの程度適合しているかを評価のために扱われます。
クロスエントロピーを使用して、2つの確率分布間の誤差(または差)を測定できます。
例として、二項分類(バイナリ分類)の場合、クロスエントロピーは次の式で表現できます。
l=(ylog(p)+(1y)log(1p))l = - (\,y \, log(p)\,\,+ \,\, (1-y) \, log(1-p)\,)

ここで:
  • pp は予測確率であり、
  • yy はインジケーターです(二項分類の場合は0または1)
特定のデータポイントで何が起こるかを見ていきましょう。 正しいインジケーターがy = 1であるとしましょう。 この場合、
l=(1×log(p)+(11)log(1p))l = - ( \, \,1 \times log(p) + (1 - 1) \, \, log (1- p) \, \,)

l=(1×log(p))l = - ( \, \, 1 \times log(p) \, \,)

損失llの値は、確率ppに依存します。したがって、損失関数は、低い損失で正しい予測(ppの高い値)を与えることでモデルに報酬を与えます。 ただし、確率が低い場合は、エラーの値が高くなる(負の値が大きくなる)ため、モデルに誤った結果のペナルティが課せられます。
マルチ分類(たとえばNクラス)問題の簡単な拡張は次のようになります。-
c=1Nyclog(pc)- \sum_{c=1}^{N} y_c log(p_c)


🧑🏼‍💻 コード

このセクションでは、TensorflowとPyTorchの両方でクロスエントロピー損失を使用してwandbにログを記録する方法について説明します。

🥕 Tensorflowの実装

import tensorflow as tf
from wandb.keras import WandbCallback

def build_model():
...

# Define the Model Architecture
model = tf.keras.Model(inputs = ..., outputs = ...)

# Define the Loss Function -> BinaryCrossentropy or CategoricalCrossentropy
fn_loss = tf.keras.losses.BinaryCrossentropy()

model.compile(optimizer = ..., loss = [fn_loss], metrics= ... )

return model

model = build_model()

# Create a W&B Run
run = wandb.init(...)

# Train the model, allowing the Callback to automatically sync loss
model.fit(... ,callbacks = [WandbCallback()])

# Finish the run and sync metrics
run.finish()

🔥 PyTorchの実装

import wandb
import torch.nn as nn

# Define the Loss Function
criterion = nn.CrossEntropyLoss()

# Create a W&B Run
run = wandb.init(...)

def train_step(...):
...
loss = criterion(output, target)

# Back-propagation
loss.backward()

# Log to Weights and Biases
wandb.log({"Training Loss": loss.item()})

# Finish the run and sync metrics
run.finish()

Run set
2


結論

これで、クロスエントロピー損失に関する短いチュートリアルは終わりです。 wandb機能の完全なスイートを確認するには、この短い5分間のガイドを確認してください。

📚 リソース

  • 🎥 なぜ負の対数確率を使用する必要があるのか疑問に思っている場合は、このビデオをチェックしてください
  • 🧾 より厳密な数学的説明が必要な場合は、これらのblogpost(1)およびblogpost(2)を確認してください。
この動画です:


Iterate on AI agents and models faster. Try Weights & Biases today.