Loading...
墨滴

希仔

2021/04/06  阅读:21  主题:默认主题

Prediction

Prediction

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet
def main():
    """Classify the image ``1.png`` with a trained LeNet and print its label.

    Loads the weights stored in ``Lenet.pth`` into a ``LeNet`` instance,
    preprocesses the image into a normalized 32x32 tensor, runs a single
    forward pass with gradients disabled, and prints the name of the class
    with the highest score.
    """
    # Resize to 32x32, convert to a tensor, then normalize each of the three
    # channels with mean 0.5 and std 0.5.
    transform = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Label names indexed by the predicted class id.
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    net.load_state_dict(torch.load('Lenet.pth'))
    # Inference mode; a no-op if the model has no dropout/batch-norm layers.
    net.eval()

    # Normalize above is given three channel statistics, so force a
    # three-channel image (a PNG may carry an alpha channel or be grayscale).
    im = Image.open('1.png').convert('RGB')
    im = transform(im)  # [C, H, W]
    im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]: add a leading batch dimension

    with torch.no_grad():
        outputs = net(im)
        # torch.max over dim=1 returns (values, indices); the batch holds one
        # image, so the index tensor has a single element.
        predict = torch.max(outputs, dim=1)[1].item()
    print(classes[predict])
# Run the prediction only when this file is executed as a script.
if __name__ == "__main__":
    main()
Output: plane

希仔

2021/04/06  阅读:21  主题:默认主题

作者介绍

希仔