PyTorch——模型推断——单张推断OpenCV(一)

 1 #coding= utf-8
 2 import os
 3 import torch
 4 from data_pipe import get_data
 5 from model import SimpleNet
 6 import numpy as np
 7 import cv2
 8 from PIL import Image
 9 
10 
class Infer(object):
    """Run single-image binary inference with a trained ``SimpleNet``.

    Every file in a directory is loaded with OpenCV, converted to an RGB
    float tensor in [0, 1], passed through the model one image at a time, and
    the thresholded prediction (0 or 1) is printed.
    """

    def __init__(self, model_path="./models/model_10.pth", input_size=(32, 32),
                 threshold=0.5):
        """Load the checkpoint and put the model in evaluation mode.

        Args:
            model_path: path of the ``state_dict`` checkpoint to load.
            input_size: ``(width, height)`` every image is resized to before
                inference (the argument order ``cv2.resize`` expects). The
                default matches the 32x32 input the original code hard-coded.
            threshold: a model output strictly greater than this value is
                classified as 1, otherwise 0.
        """
        self.model = SimpleNet()
        # map_location="cpu": a checkpoint saved from a GPU run would otherwise
        # fail to load on a CPU-only machine; all inference below runs on CPU.
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        # Switch to inference mode (affects dropout / batch-norm layers, if
        # SimpleNet has any).
        self.model.eval()
        self.input_size = tuple(input_size)
        self.threshold = threshold

    def _infer(self, img_tensor):
        """Return 1 if the model output for *img_tensor* exceeds the threshold, else 0.

        *img_tensor* must hold a batch of exactly one image, as produced by
        ``_to_tensor``; ``.item()`` raises if the model returns more than one
        value.
        """
        with torch.no_grad():
            output = self.model(img_tensor)
        # NOTE(review): assumes SimpleNet ends in a sigmoid so its single output
        # is a probability — TODO confirm against model.py.
        return 1 if output.item() > self.threshold else 0

    def _read_image(self, img_path):
        """Load *img_path* as an RGB uint8 array of ``self.input_size``, or None.

        ``cv2.imread`` returns None for files it cannot decode (non-images,
        directories), which callers use to skip the entry.
        """
        img = cv2.imread(img_path)  # OpenCV decodes colour images as BGR
        if img is None:
            return None
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # NOTE(review): interpolation should match whatever resizing the training
        # pipeline (data_pipe.get_data) used — TODO confirm. Images that are
        # already input_size are unchanged by this call.
        return cv2.resize(img, self.input_size)

    @staticmethod
    def _to_tensor(img):
        """Convert an (H, W, 3) uint8 image to a (1, 3, H, W) float32 tensor in [0, 1]."""
        # from_numpy rejects negative strides (e.g. img[:, :, ::-1]); make it safe.
        img = np.ascontiguousarray(img)
        tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        # unsqueeze adds the batch axis without reinterpreting memory; the old
        # reshape((1, 3, 32, 32)) scrambled any non-32x32 image with H*W == 1024
        # and crashed on every other size.
        return tensor.unsqueeze(0)

    def predict(self, path):
        """Classify every image file in directory *path*.

        For each entry (in sorted order, for reproducible output) the path is
        printed, followed by the 0/1 prediction. Entries OpenCV cannot decode
        are reported and skipped instead of aborting the whole run.

        Returns:
            dict mapping each successfully classified image path to 0 or 1.
        """
        results = {}
        for name in sorted(os.listdir(path)):
            img_path = os.path.join(path, name)
            print(img_path)
            img = self._read_image(img_path)
            if img is None:
                print("skipped: not a readable image")
                continue
            result = self._infer(self._to_tensor(img))
            print(result)
            results[img_path] = result
        return results
37 
38 
def _main():
    """Script entry point: classify every file in ./test_images and print the results."""
    Infer().predict("./test_images")


if __name__ == "__main__":
    _main()
原文地址:https://www.cnblogs.com/timelesszxl/p/14595903.html