Skip to main content

クロスエントロピー損失とは?コード付きチュートリアル

クロスエントロピー損失を解説するチュートリアル。PyTorch と TensorFlow でクロスエントロピー損失関数を実装するコード例に加え、インタラクティブな可視化も含みます。
Created on August 12|Last edited on August 12
ニューラルネットワークの学習で最も一般的に用いられる損失関数のひとつはクロスエントロピー本記事では、その導出と実装を取り上げ、使用方法を紹介しますPyTorchと TensorFlow を用いた実装方法に加え、Weights & Biases を使ってそれらを記録し可視化する方法も学びますWeights & Biases
クイックスタート:TensorFlow の Colab|PyTorch の Colab


目次



さっそく始めましょう!

クロスエントロピー損失とは何か?

クロスエントロピー損失は、分類モデルの性能を評価するために用いられる機械学習の指標です。損失(誤差)は 0 から 1 の数値で測定され、0 は完全なモデルを意味します。一般には、モデルの損失をできるだけ 0 に近づけることが目標です。
クロスエントロピー損失は、しばしばロジスティック損失(ログ損失、あるいはバイナリクロスエントロピー損失)と同じものとみなされますが、常に正しいわけではありません。
クロスエントロピー損失は、機械学習の分類モデルにおける予測確率分布と真の確率分布のずれを測る指標です。予測のあらゆる値に対する確率が保存されるため、例えばコイン投げの確率を求める場合には、表と裏に対してそれぞれ 0.5 と 0.5 の情報が格納されます。
一方、二値クロスエントロピー損失は、次の1つだけを保持します値です。つまり、0.5 だけを保持し、もう一方の 0.5 は暗黙に仮定されます(最初の確率が 0.7 なら、もう一方は 0.3 とみなされます)。また、対数を用いるため「ログ損失」とも呼ばれます。
このため、二値クロスエントロピー損失(またはログ損失)は、可能な結果が2つだけの場合に用いられます。三つ以上のクラスがあると直ちに不適切になることは容易に想像できるでしょう。そこで一般に用いられるのがクロスエントロピー損失で、三つ以上の分類候補を扱うモデルに適しています。

クロスエントロピー損失の理論背景

まず基本から始めましょう。ディープラーニングでは、通常、勾配に基づく最適化手法を用いてモデルを学習します。モデル(例えば)f(x)f(x))いくつかを用いて損失関数l(f(xi),yi)l \, (f(x_i), \, y_i)ここで(xi,yi)(x_i, y_i)ある入力と出力の組を考えます。損失関数は、モデルがどれほど「誤っている」かを測り、その誤りに基づいてモデルが改善できるようにするために用いられる指標です。すなわち、誤差の尺度です。学習を通じた目標は、この誤差(損失)を最小化することです。
損失関数の役割は非常に重要です。誤った出力に対して大きさに見合ったペナルティを与えられないと、収束が遅れ、学習に悪影響を及ぼすことがあります。
という学習パラダイムがあります最尤推定(MLE)これは、モデルが基礎となるデータ分布を学習できるよう、そのパラメータを推定するように学習する��いう意味です。したがって、モデルがデータ分布にどれだけ適合しているかを評価するために、損失関数を用います。
クロスエントロピーを用いると、2つの確率分布の誤差(差異)を測定できます。
例えば、二値分類の場合、クロスエントロピー損失は次のように定義されます。
l=(ylog(p)+(1y)log(1p))l = - (\,y \, log(p)\,\,+ \,\, (1-y) \, log(1-p)\,)

ここで:
  • ppは予測確率であり、
  • yyである指示関数00 または112クラス分類の場合)
特定のデータ点について何が起こるかを見ていきましょう。正解ラベルが、すなわち…とします。y=1y = 1この場合、
l=(1×log(p)+(11)log(1p))l = - ( \, \,1 \times log(p) + (1 - 1) \, \, log (1- p) \, \,)

l=(1×log(p))l = - ( \, \, 1 \times log(p) \, \,)

損失の値llしたがって、その確率に依存しますppしたがって、損失関数は、正しい予測に高い確率を割り当てた場合にはモデルを報い、pp損失が小さい場合は良好ですが、予測確率が低いと誤差の値(負の対数の値)は大きくなり、その結果、誤った出力に対してモデルに強いペナルティが与えられます。
多クラス分類への簡単な拡張(例えば、)NN(クラス)問題は次のように定義されます。
c=1Nyclog(pc)- \sum_{c=1}^{N} y_c log(p_c)


クロスエントロピー損失関数の実装方法

このセクションでは、PyTorch と TensorFlow の両方でクロスエントロピー損失関数を使用する方法を説明します。TensorFlowそしてPyTorchおよび Weights & Biases にログを記録します。

TensorFlow でのクロスエントロピー損失関数の実装

import tensorflow as tf
from wandb.keras import WandbCallback

def build_model():
...

# Define the Model Architecture
model = tf.keras.Model(inputs = ..., outputs = ...)

# Define the Loss Function -> BinaryCrossentropy or CategoricalCrossentropy
fn_loss = tf.keras.losses.BinaryCrossentropy()

model.compile(optimizer = ..., loss = [fn_loss], metrics= ... )

return model

model = build_model()

# Create a W&B Run
run = wandb.init(...)

# Train the model, allowing the Callback to automatically sync loss
model.fit(... ,callbacks = [WandbCallback()])

# Finish the run and sync metrics
run.finish()

PyTorch でのクロスエントロピー損失関数の実装

import wandb
import torch.nn as nn

# Define the Loss Function
criterion = nn.CrossEntropyLoss()

# Create a W&B Run
run = wandb.init(...)

def train_step(...):
...
loss = criterion(output, target)

# Back-propagation
loss.backward()

# Log to Weights and Biases
wandb.log({"Training Loss": loss.item()})

# Finish the run and sync metrics
run.finish()

Run set
0


まとめ

以上でクロスエントロピー損失に関する短いチュートリアルは終了です。より充実した機能一式をご覧になるには、Weights & Biases の機能こちらの短い資料をご覧ください5分ガイド

関連リソース

  • 負の対数確率をなぜ使うのか気になる方は、こちらをご覧ください。動画🎥
  • より厳密な数学的な説明が必要な場合は、次を参照してください。
クリックせずに見られるよう、動画をこちらに埋め込ん���おきます。

Mahmoud Limam
Mahmoud Limam •  
Hi thanks for the article. I noticed it says in the beginning that the loss is between 0 and 1, which isn't the case with cross entropy as -log(p) can certainly exceed 1 when p is close enough to 0.
Reply