1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
class NIN(nn.Module):
    """Network-in-Network style convolutional classifier.

    The model is three stacked feature stages, each made of one spatial
    convolution followed by two 1x1 convolutions (every convolution is
    followed by a ReLU). The first two stages end with a 2x2 max-pool
    (stride 2) and ``Dropout2d``; the third stage maps directly to
    ``out_channels`` feature maps, which are reduced to one score per
    channel by global average pooling.

    Args:
        in_channels (int): Number of channels in the input images.
        out_channels (int): Number of output scores per sample (the channel
            count of the last 1x1 convolution).
    """

    def __init__(self, in_channels=1, out_channels=10):
        super(NIN, self).__init__()

        # Stage 1: 5x5 conv (padding keeps H x W), two 1x1 convs, then
        # halve the spatial resolution.
        self.features1 = nn.Sequential(
            nn.Conv2d(in_channels, 192, (5, 5), stride=1, padding=2),
            nn.ReLU(),
            nn.Conv2d(192, 160, (1, 1), stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(160, 96, (1, 1), stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Dropout2d(),
        )

        # Stage 2: same layout as stage 1 with 192 channels throughout.
        self.features2 = nn.Sequential(
            nn.Conv2d(96, 192, (5, 5), stride=1, padding=2),
            nn.ReLU(),
            nn.Conv2d(192, 192, (1, 1), stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(192, 192, (1, 1), stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Dropout2d(),
        )

        # Stage 3: 3x3 conv and two 1x1 convs; the last one emits one
        # feature map per output score. No pooling or dropout here.
        self.features3 = nn.Sequential(
            nn.Conv2d(192, 192, (3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(192, 192, (1, 1), stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(192, out_channels, (1, 1), stride=1, padding=0),
            nn.ReLU(),
        )

        # Global average pooling. An adaptive pool with output size 1 averages
        # over the whole feature map whatever its size, whereas a fixed 8x8
        # window is only "global" for 8x8 feature maps (i.e. 32x32 inputs) and
        # fails on smaller ones such as 28x28 inputs (7x7 feature maps).
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, inputs):
        """Compute per-sample scores for a batch of images.

        Args:
            inputs: Tensor of shape ``(N, in_channels, H, W)``. ``H`` and ``W``
                must be at least 4 so that a non-empty feature map remains
                after the two stride-2 max-pools.

        Returns:
            Tensor of shape ``(N, out_channels)``. The values are non-negative
            because the last convolution is followed by a ReLU before the
            average pooling.
        """
        x = self.features1(inputs)
        x = self.features2(x)
        x = self.features3(x)
        x = self.gap(x)  # (N, out_channels, 1, 1)
        # Collapse the two singleton spatial dimensions.
        return x.flatten(1)
|