1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
# -*- coding: utf-8 -*-
import os import time import copy import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader import torchvision.transforms as transforms from torchvision.datasets import ImageFolder import torchvision
import models.zfnet as zfnet import utils.util as util
# Root directory passed to load_data(); it is joined with the 'train' and
# 'val' sub-directory names there.
data_root_dir = '../data/train_val/'
# Directory the script entry point saves each trained model's '<name>.pth' to.
model_dir = '../data/models/'
def load_data(root_dir, batch_size=128, num_workers=8):
    """Build the 'train' and 'val' data loaders for an ImageFolder layout.

    ``root_dir`` must contain a ``train`` and a ``val`` sub-directory, each
    readable by ``torchvision.datasets.ImageFolder``.

    Args:
        root_dir: Directory holding the ``train`` and ``val`` folders.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes used by each loader.

    Returns:
        A ``(data_loaders, dataset_sizes)`` tuple of dicts keyed by phase
        name ('train' / 'val'): the ``DataLoader`` for that phase and the
        number of samples in its dataset.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # Random augmentation is applied to the training split only, so that
    # validation metrics are computed on the same pixels every epoch.
    phase_transforms = {
        'train': transforms.Compose([
            transforms.Resize((227, 227)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize((227, 227)),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    data_loaders = {}
    dataset_sizes = {}
    for phase in ['train', 'val']:
        phase_dir = os.path.join(root_dir, phase)
        data_set = ImageFolder(phase_dir, transform=phase_transforms[phase])

        # Shuffling only matters for optimisation; keep validation ordered.
        data_loaders[phase] = DataLoader(
            data_set,
            batch_size=batch_size,
            shuffle=(phase == 'train'),
            num_workers=num_workers,
        )
        dataset_sizes[phase] = len(data_set)

    return data_loaders, dataset_sizes
def train_model(model, criterion, optimizer, scheduler, dataset_sizes, dataloaders, num_epochs=25, device=None):
    """Train ``model`` and return it loaded with its best validation weights.

    Every epoch runs a 'train' phase followed by a 'val' phase. The weights
    giving the highest validation accuracy are kept and loaded back into the
    model before returning.

    Args:
        model: The ``torch.nn.Module`` to optimise (modified in place).
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            training batches.
        dataset_sizes: Dict mapping 'train' / 'val' to the sample count used
            to average the loss and accuracy.
        dataloaders: Dict mapping 'train' / 'val' to an iterable yielding
            ``(inputs, labels)`` batches.
        num_epochs: Number of epochs to run.
        device: Device the batches are moved to; ``None`` leaves them where
            they are.

    Returns:
        ``(model, loss_dict, acc_dict)`` where the two dicts map each phase
        name to a list of per-epoch Python floats.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    loss_dict = {'train': [], 'val': []}
    acc_dict = {'train': [], 'val': []}
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # enable training-mode layers
            else:
                model.eval()  # switch to evaluation mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                if device is not None:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                # Gradients accumulate by default, so clear them per batch.
                optimizer.zero_grad()

                # Only track autograd history while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # loss is a batch mean; weight it by the batch size so the
                # epoch figure is a per-sample average.
                running_loss += loss.item() * inputs.size(0)
                # Accumulate as a Python int so the stored metrics are plain
                # floats rather than (possibly GPU-resident) tensors.
                running_corrects += torch.sum(preds == labels).item()

            if is_train:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            loss_dict[phase].append(epoch_loss)
            acc_dict[phase].append(epoch_acc)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Snapshot the weights whenever validation accuracy improves;
            # deepcopy because state_dict() references the live parameters.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Restore the best validation weights before handing the model back.
    model.load_state_dict(best_model_wts)
    return model, loss_dict, acc_dict
if __name__ == '__main__':
    # Script entry point: train AlexNet and ZFNet on the same loaders, save
    # each model's best weights, then plot the loss and accuracy histories.
    data_loaders, data_sizes = load_data(data_root_dir)
    print(data_sizes)

    loss_curves = {}
    acc_curves = {}
    for name in ('alexnet', 'zfnet'):
        # Both networks are built with a 20-class output layer.
        model = (
            torchvision.models.AlexNet(num_classes=20)
            if name == 'alexnet'
            else zfnet.ZFNet(num_classes=20)
        )
        device = util.get_device()
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        # Alternative optimizer left by the author:
        #   optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        # Multiply the learning rate by 0.1 every 15 scheduler steps.
        step_scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=15, gamma=0.1)

        trained, loss_hist, acc_hist = train_model(
            model,
            criterion,
            optimizer,
            step_scheduler,
            data_sizes,
            data_loaders,
            num_epochs=50,
            device=device,
        )

        # Save the best model parameters.
        util.check_dir(model_dir)
        weights_path = os.path.join(model_dir, '%s.pth' % name)
        torch.save(trained.state_dict(), weights_path)

        loss_curves[name] = loss_hist
        acc_curves[name] = acc_hist

        print('train %s done' % name)
        print()

    util.save_png('loss', loss_curves)
    util.save_png('acc', acc_curves)
|