# YOLOv5 测试 (YOLOv5 test): real-time object detection on a camera stream.

import argparse
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from numpy import random

from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadStreams2, LoadImages, LoadWebcam, letterbox
from utils.general import (check_img_size, check_requirements, non_max_suppression, apply_classifier,
                           scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path)
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized


# Inference settings read by detectionObjectFunction().
augment = False    # forwarded to model(..., augment=...)
conf_thres = 0.55  # confidence threshold passed to non_max_suppression
iou_thres = 0.45   # IoU threshold passed to non_max_suppression
img_size = 640     # target size passed to letterbox(new_shape=...)

# Pick the compute device ('' presumably lets select_device auto-choose — verify)
# and load the YOLOv5s weights onto it.
device = select_device('')
model = attempt_load('yolov5s.pt', map_location=device)

# Class-name list, taken from the wrapped `.module` when the model has one,
# plus one random 3-component color per class for drawing boxes.
names = (model.module if hasattr(model, 'module') else model).names
colors = [[random.randint(0, 255) for _channel in range(3)] for _name in names]


def _preprocess_frame(frame):
    """Turn one OpenCV camera frame into a normalized model-input tensor.

    The frame is letterboxed to ``img_size``, flipped from BGR to RGB,
    transposed from HWC to CHW, moved to ``device``, scaled from 0-255 to
    0.0-1.0 and given a leading batch dimension (result: 1x3xHxW).
    """
    img = letterbox(frame, new_shape=img_size)[0]
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, HWC -> CHW
    img = np.ascontiguousarray(img)  # the flip/transpose leave a non-contiguous view
    img = torch.from_numpy(img).to(device).float()
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)  # add the batch dimension
    return img


def detectionObjectFunction(source=2, window_name="win1"):
    """Run YOLOv5 detection on a live camera stream and display the results.

    Each frame from ``cv2.VideoCapture(source)`` is preprocessed, passed
    through the module-level ``model``, and filtered with non-maximum
    suppression (``conf_thres`` / ``iou_thres``). The resulting boxes are
    drawn onto a copy of the frame, which is shown in an OpenCV window.
    The loop stops when ESC is pressed or when a frame cannot be read.
    The camera and the OpenCV windows are always released on exit.

    Args:
        source: Anything accepted by ``cv2.VideoCapture``. Defaults to
            camera index 2, the value the script originally hard-coded.
        window_name: Title of the OpenCV display window.
    """
    vc = cv2.VideoCapture(source)
    try:
        while True:
            rval, frame = vc.read()
            if not rval:
                # Stream ended or camera unavailable: frame is None here and
                # letterbox() would crash on it.
                print('camera frame could not be read, stopping.')
                break

            img = _preprocess_frame(frame)
            im0 = frame.copy()  # boxes are drawn on this full-resolution copy

            # Inference + NMS. no_grad: no autograd graph is needed per frame.
            t1 = time_synchronized()
            with torch.no_grad():
                pred = model(img, augment=augment)[0]
                pred = non_max_suppression(pred, conf_thres, iou_thres)
            t2 = time_synchronized()

            # One detection tensor per image in the batch (batch size 1 here).
            for det in pred:
                if len(det):
                    # Rescale boxes from the letterboxed input size to im0 size.
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
                    for *xyxy, conf, cls in reversed(det):
                        label = f'{names[int(cls)]} {conf:.2f}'
                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=2)

                # Print time (inference + NMS)
                print(f'detection time. ({t2 - t1:.3f}s)')
                cv2.imshow(window_name, im0)

            if cv2.waitKey(10) == 27:  # 27 == ESC
                break
    finally:
        vc.release()
        cv2.destroyAllWindows()



if __name__ == "__main__":
    # Start the blocking camera loop only when run as a script, not on import.
    detectionObjectFunction()

# QQ 3087438119
# 原文地址 (original article): https://www.cnblogs.com/herd/p/14638403.html