Home | 简体中文 | 繁体中文 | 杂文 | Github | 知乎专栏 | Facebook | Linkedin | Youtube | 打赏(Donations) | About
知乎专栏

10.7. 保存/加载模型

下面是一个识别手写数字(MNIST)的例子:首先完成训练,然后将模型参数保存到文件,最后重新加载模型,并用它识别一张图片。

		
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.autograd import Variable
from torchvision.transforms import transforms
from PIL import Image, ImageOps
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from matplotlib import pyplot as plt


class Net(nn.Module):
    """Small convolutional classifier for single-channel images, 10 classes.

    The first fully connected layer expects 20 * 10 * 10 features, which is
    what the convolution/pooling stack below yields for 28x28 inputs
    (28 -> 24 after the 5x5 conv, -> 12 after pooling, -> 10 after the 3x3
    conv, with 20 channels).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys used when saving/loading the weights.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=3)
        self.fc1 = nn.Linear(in_features=20 * 10 * 10, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10)."""
        batch = x.size(0)
        # conv -> ReLU -> 2x2 max-pool (stride 2)
        features = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        # conv -> ReLU (no pooling after the second convolution)
        features = F.relu(self.conv2(features))
        # Flatten everything except the batch dimension for the linear layers.
        hidden = F.relu(self.fc1(features.view(batch, -1)))
        return F.log_softmax(self.fc2(hidden), dim=1)


# Train the model
def train(model, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    Args:
        model: Network to optimise; switched to training mode here.
        device: Device the batches are moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches that also
            exposes ``.dataset`` and ``len()`` (used for progress output).
        optimizer: Optimizer holding the model's parameters.
        epoch: Epoch number, used only in the progress message.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear stale gradients, then forward / loss / backward / update.
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        loss.backward()
        optimizer.step()

        # Report progress on every 10th batch only.
        if step % 10 != 0:
            continue
        seen = step * len(inputs)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), percent, loss.item()))


# 测试模型
def test(model, device, test_loader):
    model.eval()
    correct = 0.0
    test_loss = 0.0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss = F.cross_entropy(output, target).item()
            # pred = output.max(1, keepdim=True)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        print("Test Average loss: {:.4f}, Accuracy: {}/{} {:.0f}%".format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))


def debug():
    """Display the first image of the MNIST test split with matplotlib.

    The dataset is read from the ``data`` directory with downloading
    disabled, so it must already be present there. ``plt.show()`` is
    called at the end to display the figure.
    """
    # Only the test split is needed; no transform is applied, the raw image
    # tensor stored on the dataset is shown directly.
    test_datasets = datasets.MNIST(root='data', train=False, download=False)

    plt.imshow(test_datasets.data[0])
    plt.show()


def ocr(device):
    """Load the saved weights and classify the digit stored in ``test.png``.

    The image is resized to 28x28, converted to grayscale and then scaled
    and normalised with the same constants the training pipeline in this
    file uses (``ToTensor`` followed by ``Normalize((0.1307,), (0.3081,))``),
    so the network sees inputs in the range it was trained on.

    Args:
        device: Device (string or ``torch.device``) to run inference on.

    Returns:
        The predicted class index as an ``int``. The class probabilities
        and the prediction are also printed.
    """
    model = Net().to(device)
    # The file holds only a state_dict, so the restricted unpickler
    # (weights_only=True) is sufficient and avoids executing arbitrary
    # pickled code. map_location lets weights saved on another device
    # (e.g. a GPU) be loaded onto the requested one.
    state_dict = torch.load('mnist_cnn.pt', map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Close the file handle as soon as the pixel data has been copied out.
    with Image.open('test.png') as source:
        image = ImageOps.grayscale(source.resize((28, 28)))

    # NOTE(review): MNIST digits are light strokes on a dark background;
    # presumably test.png follows the same convention -- confirm, and invert
    # the image first if it does not.
    pixels = np.asarray(image, dtype=np.float32) / 255.0  # same scaling as ToTensor
    pixels = (pixels - 0.1307) / 0.3081  # same constants as the training Normalize
    # Add batch and channel dimensions: [28, 28] -> [1, 1, 28, 28].
    batch = torch.from_numpy(pixels).unsqueeze(0).unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(batch)
        # The network returns log-probabilities; softmax maps them back to
        # plain probabilities.
        prob = F.softmax(output, dim=1)

    # Move to the CPU before converting to NumPy (required for GPU tensors).
    prob = prob.cpu().numpy()
    print(prob)  # probabilities of the 10 classes
    pred = int(np.argmax(prob))  # most likely class
    print(pred)
    return pred


def main():
    """Train the CNN on MNIST on the CPU and save its weights.

    Runs 10 epochs with batch size 16, evaluating on the test split after
    every epoch, then writes the model's state_dict to ``mnist_cnn.pt``.
    """
    # Hyper-parameters
    device = torch.device("cpu")
    epochs = 10
    batch_size = 16

    # Convert images to tensors and normalise them.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Datasets and loaders; only the training split triggers a download.
    train_datasets = datasets.MNIST(
        root='data', train=True, transform=preprocess, download=True)
    test_datasets = datasets.MNIST(
        root='data', train=False, transform=preprocess, download=False)
    train_dataloader = DataLoader(
        train_datasets, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(
        test_datasets, batch_size=batch_size, shuffle=True)

    # Model and optimizer (Adam with its default settings).
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters())

    # Epochs are numbered from 1 in the progress output.
    for epoch_index in range(epochs):
        train(model, device, train_dataloader, optimizer, epoch_index + 1)
        test(model, device, test_dataloader)

    # Persist only the learned parameters, not the whole module object.
    torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == "__main__":
    # Train and save the model first, then reload the saved weights and
    # classify test.png on the CPU.
    main()
    ocr("cpu")
		
		
		

完成训练后,使用 torch.save(model.state_dict(), "mnist_cnn.pt") 保存模型参数(state_dict)。

将模型文件复制到生产环境后,先实例化 Net,再使用 model.load_state_dict(torch.load('mnist_cnn.pt', weights_only=False)) 将参数加载到内存,并调用 model.eval() 切换到推理模式,之后便可以随时使用。