PyTorch——模型推断——单张推断OpenCV(二)

import os
import sys

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from data_pipe import get_data
from resnet18 import ResNet18
from vgg import VGG_13
10 
11 
class Infer(object):
    """Classify the images in a directory with a trained ResNet18 checkpoint.

    Each image is read with OpenCV, converted from BGR to RGB, resized to
    224x224, normalized with the mean/std below (the commonly used ImageNet
    statistics -- presumably the same as the training transform, confirm),
    and fed to the model one at a time. The predicted class name is printed.
    """

    def __init__(self, model_path="./models/model_65.pth"):
        """Build the network and load its weights.

        Args:
            model_path: path to a saved ``state_dict`` for ``ResNet18``.
        """
        self.model = ResNet18()
        # map_location="cpu" lets a checkpoint saved on a GPU load on a
        # CPU-only machine; inference below runs on CPU tensors anyway.
        self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
        self.model.eval()
        # Class name -> output index. Indices follow the lexicographic order
        # of the (space-prefixed) names, presumably the class_to_idx mapping
        # produced at training time -- verify against the training pipeline.
        self.cls = {' 0': 0, ' 1': 1, ' 10': 2, ' 11': 3, ' 12': 4, ' 13': 5, ' 14': 6, ' 15': 7, ' 16': 8, ' 17': 9, ' 18': 10, ' 19': 11, ' 2': 12, ' 20': 13, ' 21': 14, ' 22': 15, ' 23': 16, ' 24': 17, ' 25': 18, ' 26': 19, ' 27': 20, ' 28': 21, ' 29': 22, ' 3': 23, ' 30': 24, ' 31': 25, ' 32': 26, ' 33': 27, ' 34': 28, ' 35': 29, ' 36': 30, ' 37': 31, ' 38': 32, ' 39': 33, ' 4': 34, ' 5': 35, ' 6': 36, ' 7': 37, ' 8': 38, ' 9': 39}
        # Inverse mapping: output index -> class name.
        self.new_cls = {idx: name for name, idx in self.cls.items()}

    def _infer(self, img_tensor):
        """Run a forward pass without gradient tracking and return the raw output."""
        with torch.no_grad():
            return self.model(img_tensor)

    def _decode(self, result):
        """Return the stripped class name of the highest-scoring index in row 0 of ``result``."""
        # argmax returns the first maximal index on ties, like torch.max.
        pred = int(result.argmax(dim=1)[0])
        return self.new_cls[pred].strip()

    def predict(self, path):
        """Classify every entry of directory ``path`` and print its label.

        Entries are processed in sorted name order so the output is
        reproducible. Entries OpenCV cannot decode (non-image files,
        sub-directories) are reported on stderr and skipped instead of
        aborting the run.

        Args:
            path: directory containing the images to classify.

        Returns:
            list of ``(image_path, label)`` tuples for the images that were
            classified, in the order their labels were printed.
        """
        transform = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        results = []
        for name in sorted(os.listdir(path)):
            img_path = os.path.join(path, name)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals failure by returning None, not by raising.
                print("skipping unreadable image: %s" % img_path, file=sys.stderr)
                continue
            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            # Add the batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
            img_tensor = transform(img).unsqueeze(0)
            label = self._decode(self._infer(img_tensor))
            print(label)
            results.append((img_path, label))
        return results
40 
41 
if __name__ == "__main__":
    # Script entry point: classify every image in the local test directory.
    image_dir = "./test_images"
    classifier = Infer()
    classifier.predict(image_dir)
原文地址:https://www.cnblogs.com/timelesszxl/p/14611393.html