1. 交叉熵简介
交叉熵 (Cross Entropy) 源自信息论,要理解它需要先了解几个基本概念。
1.1 信息量与熵
信息量衡量一个事件发生时带来的”惊讶程度”。事件 $x$ 的信息量定义为:
\[I(x) = -\log p(x)\]香农熵 (Shannon Entropy) 是信息量的期望值,衡量随机变量的不确定性:
\[H(p) = -\sum_{x} p(x) \log p(x)\]熵越大,分布越均匀,不确定性越高;熵越小,分布越集中,不确定性越低。
1.2 交叉熵
交叉熵衡量用分布 $q$ 编码来自分布 $p$ 的数据时所需的平均比特数:
\[H(p, q) = -\sum_{x} p(x) \log q(x)\]当 $q = p$ 时,$H(p, q) = H(p)$,即交叉熵等于熵本身,此时编码效率最高。
1.3 KL 散度
KL 散度 (Kullback-Leibler Divergence) 衡量两个分布之间的差异:
\[D_{KL}(p \| q) = H(p, q) - H(p) = \sum_{x} p(x) \log \frac{p(x)}{q(x)}\]由于真实分布 $p$ 的熵 $H(p)$ 是常数,最小化交叉熵等价于最小化 KL 散度,这就是交叉熵能作为损失函数的理论基础。
2. 交叉熵损失函数
2.1 二分类交叉熵 (Binary Cross Entropy)
对于二分类问题,真实标签 $y \in {0, 1}$,模型预测概率 $\hat{y} \in (0, 1)$:
\[L = -[y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})]\]- 当 $y = 1$ 时,$L = -\log(\hat{y})$,预测越接近 1 损失越小
- 当 $y = 0$ 时,$L = -\log(1 - \hat{y})$,预测越接近 0 损失越小
2.2 多分类交叉熵 (Categorical Cross Entropy)
对于 $C$ 个类别的分类问题,真实标签为 one-hot 向量 $\mathbf{y}$,模型输出概率分布 $\hat{\mathbf{y}}$:
\[L = -\sum_{c=1}^{C} y_c \log(\hat{y}_c)\]由于 one-hot 编码中只有真实类别 $k$ 对应的 $y_k = 1$,其余为 0,上式简化为:
\[L = -\log(\hat{y}_k)\]这意味着损失仅取决于模型对正确类别的预测概率。
3. 与 Softmax 的关系
3.1 Softmax 函数
Softmax 将模型输出的原始分数 (logits) 转换为概率分布:
\[\sigma(\vec{z})_{i} = \frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}}\]满足 $\sum_i \sigma(\vec{z})_i = 1$ 且 $\sigma(\vec{z})_i > 0$。
3.2 Softmax 的导数
将 Softmax 函数记为 $S_i$,则导数为:
\[\frac{\partial S_i}{\partial z_j} = \begin{cases} S_i(1 - S_j) & i = j \\ -S_i S_j & i \neq j \end{cases}\]3.3 组合梯度推导
将 Softmax 输出代入交叉熵损失 $L = -\log(S_k)$,对 logit $z_j$ 求导:
\[\frac{\partial L}{\partial z_j} = -\frac{1}{S_k} \cdot \frac{\partial S_k}{\partial z_j}\]- 当 $j = k$ 时:$\frac{\partial L}{\partial z_k} = -\frac{1}{S_k} \cdot S_k(1 - S_k) = S_k - 1 = \hat{y}_k - 1$
- 当 $j \neq k$ 时:$\frac{\partial L}{\partial z_j} = -\frac{1}{S_k} \cdot (-S_k S_j) = S_j = \hat{y}_j$
统一写成向量形式:
\[\frac{\partial L}{\partial \mathbf{z}} = \hat{\mathbf{y}} - \mathbf{y}\]这个结果非常优雅——Softmax + 交叉熵的梯度就是预测概率与真实标签之间的差值。这也是为什么二者总是组合使用的原因。
4. 数值稳定性
4.1 溢出问题
直接计算 $e^{z_i}$ 容易导致数值上溢。解决方法是 Log-Sum-Exp 技巧,利用 Softmax 的平移不变性:
\[\log \sum_{j} e^{z_j} = m + \log \sum_{j} e^{z_j - m}, \quad m = \max_j z_j\]减去最大值后指数运算不会溢出,同时结果保持不变。
4.2 PyTorch 的设计选择
PyTorch 中 nn.CrossEntropyLoss 接收原始 logits(而非 Softmax 输出),内部将 Softmax 和交叉熵合并计算。这样做有两个好处:
- 数值稳定:利用 log-sum-exp 技巧避免溢出
- 计算高效:避免先算 Softmax 再取 log 的冗余操作
5. 代码实现
5.1 手动实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torch.nn.functional as F
def cross_entropy_manual(logits, targets):
"""手动实现交叉熵损失(含数值稳定处理)"""
# Log-Sum-Exp 技巧
max_val = logits.max(dim=-1, keepdim=True).values
log_sum_exp = max_val + torch.log(torch.exp(logits - max_val).sum(dim=-1, keepdim=True))
log_probs = logits - log_sum_exp
# 取出真实类别对应的 log 概率
loss = -log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
return loss.mean()
# 测试
logits = torch.randn(4, 5) # batch_size=4, num_classes=5
targets = torch.randint(0, 5, (4,))
print(f"手动实现: {cross_entropy_manual(logits, targets):.4f}")
print(f"PyTorch: {F.cross_entropy(logits, targets):.4f}")
5.2 PyTorch 内置接口
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import torch.nn as nn
logits = torch.randn(4, 5)
targets = torch.randint(0, 5, (4,))
# 方式一:nn.CrossEntropyLoss(输入为原始 logits)
criterion = nn.CrossEntropyLoss()
loss1 = criterion(logits, targets)
# 方式二:nn.NLLLoss(输入为 log-softmax 输出)
log_probs = torch.log_softmax(logits, dim=-1)
loss2 = nn.NLLLoss()(log_probs, targets)
print(f"CrossEntropyLoss: {loss1:.4f}")
print(f"NLLLoss: {loss2:.4f}") # 结果一致
5.3 常见陷阱:双重 Softmax
1
2
3
4
5
6
# ❌ 错误用法:先 softmax 再传给 CrossEntropyLoss
probs = torch.softmax(logits, dim=-1)
wrong_loss = nn.CrossEntropyLoss()(probs, targets) # 内部会再做一次 softmax!
# ✅ 正确用法:直接传入 logits
correct_loss = nn.CrossEntropyLoss()(logits, targets)
CrossEntropyLoss 内部已经包含 Softmax,如果提前对 logits 做了 Softmax,相当于做了两次,会导致梯度信号被压缩,训练效果变差。
6. 总结
| 概念 | 说明 |
|---|---|
| 交叉熵 | 衡量预测分布与真实分布的差异 |
| 二分类 CE | $-[y\log\hat{y} + (1-y)\log(1-\hat{y})]$ |
| 多分类 CE | $-\sum y_c \log \hat{y}_c$,one-hot 下简化为 $-\log \hat{y}_k$ |
| Softmax + CE 梯度 | $\hat{\mathbf{y}} - \mathbf{y}$,形式简洁,便于反向传播 |
| 数值稳定性 | 使用 log-sum-exp 技巧,PyTorch 接收 logits 而非概率值 |
| 常见错误 | 不要在 CrossEntropyLoss 之前手动调用 Softmax |