Label Smoothing Regularization

Definition

In classification tasks, models are usually trained with the cross-entropy loss, which maximizes the log-likelihood of the correct label. The loss is computed as:

\[ H(y, p) = -\sum_{k=1}^{K} y_{k}\log(p_{k}) \]

Here \(y_{k} = 1\) when \(k\) is the correct class and \(0\) otherwise. Training with such hard targets leads to two problems (a minimal computation sketch follows the list below):

  1. It can lead to overfitting: if the model learns to assign the full probability mass to the ground-truth label for every training example, it is not guaranteed to generalize.
  2. It encourages the gap between the largest logit and all the others to grow, which, combined with the bounded gradient \(\frac{\partial \ell}{\partial z_{k}}\), reduces the model's ability to adapt.
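
As a quick sketch of the hard-target cross-entropy above (the logits and label are made-up values):

import torch
import torch.nn.functional as F

# Logits for one example over K = 4 classes; class 2 is the ground truth.
logits = torch.tensor([[1.0, 2.0, 3.0, 0.5]])
target = torch.tensor([2])

# With a hard (one-hot) target, only the log-probability of the
# correct class contributes to H(y, p).
log_probs = F.log_softmax(logits, dim=1)
print(-log_probs[0, 2].item())                  # manual H(y, p)
print(F.cross_entropy(logits, target).item())   # equivalent built-in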

The goal of label smoothing regularization (LSR) is to prevent the largest logit from becoming much larger than all the others. It is implemented by mixing the hard target in the cross-entropy loss with a label distribution \(u(k)\) that is independent of the training example:

\[ y(k) = (1 - \epsilon)\delta_{k,y} + \epsilon u(k) \]

  • \(k\) indexes the \(K\) labels
  • \(\epsilon\) is the smoothing parameter
  • \(\delta_{k,y}\) is the Dirac delta, equal to \(1\) when \(k\) matches the ground-truth label \(y\) and \(0\) otherwise

In effect, LSR encourages the network to choose the correct class while keeping the gap between the correct class and each of the incorrect classes consistent, which pushes the gradient toward the correct class and away from the incorrect ones.

The paper sets \(u(k)\) to the uniform distribution \(u(k) = 1/K\), so

\[ y(k) = (1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K} \]

\(\epsilon\) can be set to \(0.1\), the value used in the paper.
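
A minimal sketch of building this smoothed target (the class count, label, and \(\epsilon\) below are assumed values):

import torch

# K = 5 classes, ground-truth label y = 2, epsilon = 0.1, uniform u(k) = 1/K.
K, y, eps = 5, 2, 0.1
delta = torch.zeros(K)
delta[y] = 1.0                     # hard one-hot target: delta_{k,y}
u = torch.full((K,), 1.0 / K)      # label-independent distribution u(k)
smoothed = (1 - eps) * delta + eps * u
print(smoothed)                    # tensor([0.0200, 0.0200, 0.9200, 0.0200, 0.0200])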

PyTorch Implementation

Below is an implementation found online that uses KLDivLoss:

import torch
import torch.nn as nn


class SmoothLabelCritierion(nn.Module):
    """Cross-entropy loss with label smoothing.

    Builds the smoothed target distribution and computes the loss
    against the log-probabilities of the model output.
    """

    def __init__(self, label_smoothing=0.0):
        super(SmoothLabelCritierion, self).__init__()
        self.label_smoothing = label_smoothing
        self.LogSoftmax = nn.LogSoftmax(dim=1)

        # When label smoothing is turned on, the KL divergence between
        # the smoothed target distribution and the model output is
        # minimized. If the label smoothing value is set to zero, the
        # loss is equivalent to NLLLoss (i.e. CrossEntropyLoss on logits).
        if label_smoothing > 0:
            self.criterion = nn.KLDivLoss(reduction='batchmean')
        else:
            self.criterion = nn.NLLLoss()
        self.confidence = 1.0 - label_smoothing

    def _smooth_label(self, num_tokens):
        # Template row: every class receives epsilon / (K - 1); the
        # ground-truth position is overwritten with `confidence` later.
        # Note this spreads epsilon over the K - 1 wrong classes rather
        # than over all K classes as in the paper's uniform u(k) = 1/K.
        one_hot = torch.full((1, num_tokens),
                             self.label_smoothing / (num_tokens - 1))
        return one_hot

    def _bottle(self, v):
        # Flatten (batch, seq, num_tokens) to (batch * seq, num_tokens).
        return v.view(-1, v.size(2))

    def forward(self, dec_outs, labels):
        # Log-probabilities of the model outputs
        scores = self.LogSoftmax(dec_outs)
        # Number of classes
        num_tokens = scores.size(-1)

        gtruth = labels.view(-1)
        if self.confidence < 1:
            tdata = gtruth.detach()
            one_hot = self._smooth_label(num_tokens)
            if labels.is_cuda:
                one_hot = one_hot.cuda()
            # One smoothed row per example, then place `confidence`
            # at each example's ground-truth index.
            tmp_ = one_hot.repeat(gtruth.size(0), 1)
            tmp_.scatter_(1, tdata.unsqueeze(1), self.confidence)
            gtruth = tmp_.detach()
        loss = self.criterion(scores, gtruth)
        return loss


if __name__ == '__main__':
    outputs = torch.randn((1, 10))
    targets = torch.ones(1).long()
    print(outputs)
    print(targets)

    criterion = SmoothLabelCritierion(label_smoothing=0.1)
    loss = criterion(outputs, targets)
    print(loss)
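
For comparison, PyTorch (since version 1.10) builds label smoothing directly into nn.CrossEntropyLoss, following the paper's \(\epsilon / K\) scheme; the class above spreads \(\epsilon\) over the \(K - 1\) incorrect classes instead, so the two losses differ slightly. A quick cross-check:

import torch
import torch.nn as nn

# Built-in label smoothing (PyTorch >= 1.10): the smoothed target is
# (1 - epsilon) * one_hot + epsilon / K, matching the paper's uniform u(k).
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
outputs = torch.randn(4, 10)            # logits: batch of 4, 10 classes
targets = torch.randint(0, 10, (4,))    # ground-truth class indices
print(criterion(outputs, targets))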

Related Reading