Skip to main content

クロスエントロピー損失とは?コード付きチュートリアル

クロスエントロピー損失を解説するチュートリアル。PyTorch と TensorFlow でクロスエントロピー損失関数を実装するコード例と、対話的な可視化を含みます。
Created on August 12|Last edited on August 12
ニューラルネットワークの学習で最も一般的に用いられる損失関数のひとつはクロスエントロピー本記事では、その導出と実装について、次のツールを用いて解説します PyTorch と TensorFlow を扱い、さらにそれらを用いた指標を Weights & Biases(W&B)でロギングして可視化する方法を学びます Weights & Biases
クイックスタート: TensorFlow Colab | PyTorch Colab


目次



さあ、始めましょう!

クロスエントロピー損失とは?

クロスエントロピー損失は、機械学習において分類モデルの性能を評価するために用いられる指標です。損失(または誤差)は通常、0以上の値で測られ、値が小さいほどモデルは良好です。理想は0に近づけることです。
クロスエントロピー損失は、ロジスティック損失(ログ損失、あるいは二値クロスエントロピー損失と呼ばれることもあります)と同じものとして扱われがちですが、常に正しいわけではありません。
クロスエントロピー損失は、機械学習の分類モデルにおける発見された確率分布と予測分布の差を測る指標です。予測の全ての可能な値に対する確率を保持するため、例えばコイン投げの勝率を求める場合は、表と裏に対してそれぞれ 0.5 と 0.5 の確率を格納します。
一方、二値クロスエントロピー損失では、片方の確率が決まればもう一方は 1−p で定まるため、実質的に1つの確率のみを扱います。値のみを保持します。つまり 0.5 だけを扱い、もう一方の 0.5 は自動的に補われます。例えば最初の確率が 0.7 なら、残りは 0.3 とみなされます。また、対数を用いるため「ログ損失」とも呼ばれます。
このため、二値クロスエントロピー(またはログ損失)は、取りうる結果が2つだけの場合に用いられます。しかし、クラスが3つ以上あると直ちに破綻することは容易に想像できます。そこで登場するのがクロスエントロピー損失で、3クラス以上の多クラス分類を扱うモデルで広く用いられます。

クロスエントロピー損失の理論背景

まず基本から始めましょう。ディープラーニングでは一般に、勾配に基づく最適化手法を用いてモデルを学習します。モデル(例えばf(x)f(x)いくつかを用いて損失関数l(f(xi),yi)l \, (f(x_i), \, y_i)どこで(xi,yi)(x_i, y_i)いくつかの入力と出力の組が与えられているとします。損失関数は、モデルがどれだけ「間違っている」かを評価し、その「誤り度」に基づいてモデルが改良できるようにするために用いられます。これは誤差の尺度です。学習を通しての目標は、この誤差(損失)を最小化することです。
損失関数の役割は非常に重要です。誤った出力に対して、その誤りの大きさに見合う罰則を与えられない場合、収束が遅れ、学習に悪影響を及ぼす可能性があります。
…と呼ばれる学習パラダイムがあります。 最尤推定(MLE) これは、モデルが潜在的なデータ分布を学習できるように、そのパラメータを���定するよう訓練するものです。したがって、モデルがデータ分布にどれだけ適合しているかを評価するために損失関数を用います。
クロスエントロピーを用いると、2つの確率分布間の誤差(あるいは差異)を測定できます。
例えば、二値分類の場合、クロスエントロピーは次のように定義されます。
l=(ylog(p)+(1y)log(1p))l = - (\,y \, log(p)\,\,+ \,\, (1-y) \, log(1-p)\,)

ここで
  • ppは予測確率であり、
  • yyは指示関数(00 または11二値分類の場合)
特定のデータ点について、実際にどうなるかを見ていきましょう。正解の指示関数が、つまり次のような状態だとします。y=1y = 1この場合、
l=(1×log(p)+(11)log(1p))l = - ( \, \,1 \times log(p) + (1 - 1) \, \, log (1- p) \, \,)

l=(1×log(p))l = - ( \, \, 1 \times log(p) \, \,)

損失の値llしたがって確率に依存しますpp。したがって、損失関数は、正しい予測に高い確率を与えた場合にはモデルを報いるようになります(高い確率のときは損失が小さくなる)pp)で損失は小さくなります。逆に、その確率が低い場合は損失が大きくなり(対数項は負になり、負号によって損失が増える)、誤った予測に対してモデルが強く罰せられることになります。
多クラス分類への単純な拡張(たとえばNNクラス)問題は次のように定義できます。
c=1Nyclog(pc)- \sum_{c=1}^{N} y_c log(p_c)


クロスエントロピー損失関数を実装する

このセクションでは、両方の環境でクロスエントロピー損失関数を使う方法を説明します。 TensorFlowPyTorch および Weights & Biases にログを記録します。

TensorFlowでクロスエントロピー損失を実装する

import tensorflow as tf
from wandb.keras import WandbCallback

def build_model():
...

# Define the Model Architecture
model = tf.keras.Model(inputs = ..., outputs = ...)

# Define the Loss Function -> BinaryCrossentropy or CategoricalCrossentropy
fn_loss = tf.keras.losses.BinaryCrossentropy()

model.compile(optimizer = ..., loss = [fn_loss], metrics= ... )

return model

model = build_model()

# Create a W&B Run
run = wandb.init(...)

# Train the model, allowing the Callback to automatically sync loss
model.fit(... ,callbacks = [WandbCallback()])

# Finish the run and sync metrics
run.finish()

PyTorchでクロスエントロピー損失を実装する

import wandb
import torch.nn as nn

# Define the Loss Function
criterion = nn.CrossEntropyLoss()

# Create a W&B Run
run = wandb.init(...)

def train_step(...):
...
loss = criterion(output, target)

# Back-propagation
loss.backward()

# Log to Weights and Biases
wandb.log({"Training Loss": loss.item()})

# Finish the run and sync metrics
run.finish()

Run set
0


まとめ

以上でクロスエントロピー損失に関する短いチュートリアルは終了です。全機能一式を確認するには Weights & Biases の機能、ぜひこの短いものをご覧ください 5分間ガイド

関連リソース

  • 負の対数確率をなぜ使うのか気になる方は、こちらをご覧ください 動画 🎥
  • より厳密な数理的説明を知りたい方は、こちらをご覧ください。
クリックせずに見られるよう、こちらに動画を埋め込んでおきます。

Mahmoud Limam
Mahmoud Limam •  
Hi thanks for the article. I noticed it says in the beginning that the loss is between 0 and 1, which isn't the case with cross entropy as -log(p) can certainly exceed 1 when p is close enough to 0.
Reply