定义 在分类任务中,通常使用交叉熵损失进行梯度训练。交叉熵损失的作用就是最大化正确标签的对数似然概率。其损失值计算如下:
\[ H(y, p) =-1\times \sum_{k=1}^{K} y_{k}log(p_{k}) \]
当\(y_{k}\) 属于正确类时\(=1\) ,否则\(=0\) 。这会导致两个问题:
它可能会导致过度拟合:如果模型学会为每个训练示例分配全部概率给真值标签,它就不能保证泛化效果 它鼓励最大logit
和所有其他logit
之间的差异变大,这与有界梯度\(\frac {∂l}{∂z_{k}}\) 相结合,降低了模型的迁移能力 标签平滑正则化的目的是防止最大逻辑变得比所有其他逻辑大得多。其实现方式:在交叉熵损失中加入一个独立于训练样本的基于标签的分布\(u(k)\)
\[ y(k) = (1 - \epsilon)δ_{k,y} + \epsilon u(k) \]
\(k\) 表示标签数\(\epsilon\) 表示平滑参数$δ_{k,y} \(表示标签为\) y$的训练样本 从实现上看,LSR鼓励神经网络选择正确的类,并且正确类和其余错误类之间的差别是一致的。这样能够鼓励梯度向正确类靠近的同时远离错误类
在论文中将\(u(k)\) 设置为均匀分布\(u(k) = 1/K\) ,所以
\[ y(k) = (1 - \epsilon)δ_{k,y} + \frac {\epsilon}{K} \]
\(\epsilon可设置为0.1\)
PyTorch实现 在网上找到一个实现,使用KLDivLoss
实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 import torchimport torch.nn as nnclass SmoothLabelCritierion (nn.Module ): """ TODO: 1. Add label smoothing 2. Calculate loss """ def __init__ (self, label_smoothing=0.0 ): super (SmoothLabelCritierion, self).__init__() self.label_smoothing = label_smoothing self.LogSoftmax = nn.LogSoftmax(dim=1 ) if label_smoothing > 0 : self.criterion = nn.KLDivLoss(reduction='batchmean' ) else : self.criterion = nn.NLLLoss() self.confidence = 1.0 - label_smoothing def _smooth_label (self, num_tokens ): one_hot = torch.randn(1 , num_tokens) one_hot.fill_(self.label_smoothing / (num_tokens - 1 )) return one_hot def _bottle (self, v ): return v.view(-1 , v.size(2 )) def forward (self, dec_outs, labels ): scores = self.LogSoftmax(dec_outs) num_tokens = scores.size(-1 ) gtruth = labels.view(-1 ) if self.confidence < 1 : tdata = gtruth.detach() one_hot = self._smooth_label(num_tokens) if labels.is_cuda: one_hot = one_hot.cuda() tmp_ = one_hot.repeat(gtruth.size(0 ), 1 ) tmp_.scatter_(1 , tdata.unsqueeze(1 ), self.confidence) gtruth = tmp_.detach() loss = self.criterion(scores, gtruth) return loss if __name__ == '__main__' : outputs = torch.randn((1 , 10 )) targets = torch.ones(1 ).long() print(outputs) print(targets) tmp = SmoothLabelCritierion(label_smoothing=0.1 ) loss = tmp.forward(outputs, targets) print(loss)
相关阅读