Skip to main content

クロスエントロピー損失とは?コード付きチュートリアル

クロスエントロピー損失を解説するチュートリアル。PyTorch と TensorFlow でのクロスエントロピー損失関数の実装コードと、インタラクティブな可視化を含みます。
Created on August 11|Last edited on August 11
ニューラルネットワークの学習で最も一般的に使われる損失関数のひとつはクロスエントロピー本記事では、導出と実装について、 を用いて解説します。PyTorchと TensorFlow を使った実装方法に加えて、Weights & Biases でのログ記録と可視化の手順も学びますWeights & Biases
クイックスタートTensorFlow ColabPyTorch Colab


目次



さっそく始めましょう。

クロスエントロピー損失とは?

クロスエントロピー損失は、機械学習において分類モデルの性能を測るための指標です。損失(誤差)は一般に 0 以上で、上限はありません。値が小さいほどモデルは良く、目的は損失をできるだけ小さくすることです。
クロスエントロピー損失は、ロジスティック損失(ログ損失、あるいはバイナリクロスエントロピー損失)と同義に扱われることがよくありますが、これは常に正しいわけではありません。
クロスエントロピー損失は、機械学習の分類モデルが予測した確率分布と、真の分布とのズレを測る指標です。予測では取りうる全てのクラスに対する確率が保持されます。たとえばコイン投げの例なら、表と裏に対して 0.5 と 0.5 の確率が保存されます。
一方、バイナリクロスエントロピー損失は、正例の確率だけを明示的に扱い(残りの 1−p は補数として定まります)、ただひとつの値を保持します。値です。つまり、バイナリの場合は 0.5 のみを保持し、もう一方は補数として 0.5 とみなします。最初の確率が 0.7 なら、残りは 0.3 と仮定します。さらに、対数を用いるため「log loss」と呼ばれます。
このため、バイナリクロスエントロピー損失(またはログロス)は、取り得る結果が2種類しかない場合に用いられます。もしクラスが3つ以上あれば直ちに破綻することは容易に理解できます。そこでクロスエントロピー損失の出番です。クラスが3つ以上ある分類モデルでは、こちらが一般に用いられます。

クロスエントロピー損失の理論

まずは基本から始めましょう。ディープラーニングでは、通常、勾配に基づく最適化手法を用いてモデルを学習します。モデル(例えばf(x)f(x))いくつかを用いて損失関数l(f(xi),yi)l \, (f(x_i), \, y_i)どこで(xi,yi)(x_i, y_i)入力と出力の対応関係が与えられているとき、損失関数はモデルがどれだけ間違っているかを定量化し、その「誤り」に基づいてモデルが改善できるようにする指標です。つまり誤差の尺度です。学習の目的は、この誤差(損失)を最小化することにあります。
損失関数の役割は極めて重要です。誤った出力に対して大きさに見合ったペナルティを与えられないと、収束が遅れ、学習に悪影響を及ぼします。
という学習パラダイムがあります。最尤推定(Maximum Likelihood Estimation, MLE)これにより、モデルはデータの背後にある分布を学習できるよう、パラメータを推定するように訓練されます。したがって、モデルがデータ分布にどれだけ適合しているかを評価するために損失関数を使用します。
クロスエントロピーを用いると、2つの確率分布間の誤差(差異)を測定できます。
例えば、二値分類の場合、クロスエントロピーは次のように定義されます。
l=(ylog(p)+(1y)log(1p))l = - (\,y \, log(p)\,\,+ \,\, (1-y) \, log(1-p)\,)

ここで、
  • ppは予測確率であり、
  • yyは指示変数です(00 または11二値分類の場合に
あるデータ点について何が起きるかを順に見ていきましょう。正解の指示変数が次のようになっているとします。y=1y = 1この場合、
l=(1×log(p)+(11)log(1p))l = - ( \, \,1 \times log(p) + (1 - 1) \, \, log (1- p) \, \,)

l=(1×log(p))l = - ( \, \, 1 \times log(p) \, \,)

損失の値llしたがって、その確率に依存しますppしたがって、損失関数は正解クラスに高い確率を割り当てた予測を報いるように設計されます(すなわち、正解クラスの確率が高いほど損失は小さくなる)。pp) のとき損失は小さくなります。逆に、その確率が低い場合は誤差(負の対数値)が大きくなり、誤った予測に対してモデルへ強いペナルティが課されます。
多クラス分類への素朴な拡張(たとえば…)NNクラス) の問題は次のように定義されます。
c=1Nyclog(pc)- \sum_{c=1}^{N} y_c log(p_c)


クロスエントロピー損失関数を実装する

このセクションでは、クロスエントロピー損失関数の使い方を両方の環境で解説します。TensorFlowandPyTorchおよび Weights & Biases にログを記録します。

TensorFlowでのクロスエントロピー損失関数の実装

import tensorflow as tf
from wandb.keras import WandbCallback

def build_model():
...

# Define the Model Architecture
model = tf.keras.Model(inputs = ..., outputs = ...)

# Define the Loss Function -> BinaryCrossentropy or CategoricalCrossentropy
fn_loss = tf.keras.losses.BinaryCrossentropy()

model.compile(optimizer = ..., loss = [fn_loss], metrics= ... )

return model

model = build_model()

# Create a W&B Run
run = wandb.init(...)

# Train the model, allowing the Callback to automatically sync loss
model.fit(... ,callbacks = [WandbCallback()])

# Finish the run and sync metrics
run.finish()

PyTorchでクロスエントロピー損失関数を実装する

import wandb
import torch.nn as nn

# Define the Loss Function
criterion = nn.CrossEntropyLoss()

# Create a W&B Run
run = wandb.init(...)

def train_step(...):
...
loss = criterion(output, target)

# Back-propagation
loss.backward()

# Log to Weights and Biases
wandb.log({"Training Loss": loss.item()})

# Finish the run and sync metrics
run.finish()

Run set
0


まとめ

これでクロスエントロピー損失に関する短いチュートリアルは終了です。完全な一連の機能をご覧になるには、Weights & Biases の機能、ぜひこの短い動画をご覧ください。5分ガイド

関連リソース

  • 負の対数確率をなぜ使うのか気になる方は、こちらをご覧ください。動画動画
  • より厳密な数学的な説明を知りたい場合は、次を参照してください。
クリックせずに見られるよう、動画をここに貼っておきます。

Mahmoud Limam
Mahmoud Limam •  
Hi thanks for the article. I noticed it says in the beginning that the loss is between 0 and 1, which isn't the case with cross entropy as -log(p) can certainly exceed 1 when p is close enough to 0.
Reply