centernet前向推理代码解析

本文代码转自如下链接:
https://zhuanlan.zhihu.com/p/85194783

该代码从官方代码中剥离出来,结构更清晰。数据预处理部分先按比例缩放图像,使长边变为512(短边等比例缩放),然后再把短边填充到512,很值得看一下python对这块的操作!
后续为了还原到原图坐标,再对上述操作做逆操作(先减去填充量,再除以缩放比例),更值得仔细研究!
还有batch操作。
用到的一些技巧函数:

cv2.copyMakeBorder
loaded_iminfos = list(map(prepare_img, image_frames)) ##prepare_img是函数
batch构建:im_batches = [torch.cat(img_tensors[i * opt.batch_size: np.min([(i + 1) * opt.batch_size, len(img_tensors)])]) for i in range(num_batches)]
idx.float()  obj_info[3:5].int()
ratios_select = torch.index_select(torch.Tensor(ratios), 0, output[:, 0].long())
output[:, [1, 3]] -= lefts_select.unsqueeze(1)

加了中间变量信息的注释代码如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import cv2
import numpy as np
import torch
import time

image_ext = ['jpg', 'jpeg', 'png']

import sys
CENTERNET_PATH = './src/'
sys.path.insert(0, CENTERNET_PATH)
CENTERNET_PATH = './src/lib/'
sys.path.insert(0, CENTERNET_PATH)

from models.model import create_model, load_model
from models.decode import ctdet_decode


class DefaultConfig(object):
    """Settings for the batched CenterNet inference script.

    Creating an instance also sets PyTorch's process-wide default tensor type
    (CUDA float when a GPU is available, CPU float otherwise).
    """

    def __init__(self):
        # self.demo = '../images/17790319373_bd19b24cfc_k.jpg'
        # A single image path, or a directory of images.
        self.demo = '/data_2/2019biaozhushuju/0000-2020/20200514_refdet_small_test_data/sample'
        # Side length of the square network input.
        self.reso = 512

        # Full-precision statistics, kept for reference:
        # self.mean = np.array([0.40789654, 0.44719302, 0.47026115],
        #                 dtype=np.float32).reshape(1, 1, 3)
        # self.std = np.array([0.28863828, 0.27408164, 0.27809835],
        #                dtype=np.float32).reshape(1, 1, 3)

        # Per-channel normalisation, shaped (1, 1, 3) to broadcast over HWC.
        self.mean = np.array([0.408, 0.447, 0.47],
                             dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array([0.289, 0.274, 0.278],
                            dtype=np.float32).reshape(1, 1, 3)

        self.batch_size = 4

        self.use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_gpu else "cpu")
        # Global side effect: tensors built with torch.Tensor(...) afterwards
        # are created with this type.
        if self.use_gpu:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            torch.set_default_tensor_type('torch.FloatTensor')

        self.arch = 'dla_34'
        # Output channels per head ('hm' has one channel per class).
        self.heads = {'hm': 24, 'wh': 2, 'reg': 2}
        self.head_conv = 256
        self.load_model = '/data_2/project_202009/pytorch_project/CenterNet/centernet_src_202010/CenterNet-master/myfile/model_last-3epoch-multi.pth'
        # Number of candidates decoded per image.
        self.K = 100
        # Detections scoring below this are not drawn.
        self.conf_thresh = 0.3
        # Factor used to scale decoded boxes back to network-input pixels.
        self.downsample_rate = 4

    def parse(self, kwargs):
        """Update existing config attributes from the dict ``kwargs``.

        Args:
            kwargs: mapping of attribute name to new value.

        Raises:
            ValueError: if any key is not an existing attribute. Nothing is
                modified in that case.
        """
        # The original `assert ValueError` was always true, so it never
        # rejected anything; validate explicitly before applying any value.
        unknown = [k for k in kwargs if not hasattr(self, k)]
        if unknown:
            raise ValueError('unknown config option(s): {}'.format(unknown))
        for k, v in kwargs.items():
            setattr(self, k, v)

opt = DefaultConfig()

# Build the network, load the checkpoint and switch to inference mode.
print('Creating model...')
model = create_model(opt.arch, opt.heads, opt.head_conv)
model = load_model(model, opt.load_model)
model = model.to(opt.device)
model.eval()

# Collect candidate image paths: every entry of the directory (sorted), or
# the single file that opt.demo points at.
if os.path.isdir(opt.demo):
    candidate_names = [os.path.join(opt.demo, file_name)
                       for file_name in sorted(os.listdir(opt.demo))]
else:
    candidate_names = [opt.demo]

# cv2.imread returns None for files it cannot decode. Such entries are
# skipped so image_names and image_frames stay aligned and prepare_img never
# receives None.
image_names = []
image_frames = []
for image_name in candidate_names:
    frame = cv2.imread(image_name)
    if frame is None:
        print('skip unreadable image: {}'.format(image_name))
        continue
    image_names.append(image_name)
    image_frames.append(frame)


def prepare_img(frame, color=(0, 0, 0)):
    """Resize a 3-channel image to fit opt.reso and pad it to a square tensor.

    The image is scaled by one ratio so that it fits inside an
    opt.reso x opt.reso square, padded with ``color`` on both sides of the
    shorter dimension, normalised with opt.mean / opt.std and converted
    from HWC to a 1xCxHxW tensor.

    Args:
        frame: image array with shape (h, w, c).
        color: border colour passed to cv2.copyMakeBorder.

    Returns:
        (tensor, ratio, left, top): the network input, the scale factor, and
        the left/top padding in pixels. The last three are what is needed to
        map boxes back to the original image.
    """
    src_h, src_w, _ = frame.shape
    target = opt.reso

    # One scale factor for both axes keeps the aspect ratio.
    ratio = np.min([target / src_h, target / src_w])
    scaled_h = int(src_h * ratio)
    scaled_w = int(src_w * ratio)

    if (src_w, src_h) != (scaled_w, scaled_h):
        frame = cv2.resize(frame, (scaled_w, scaled_h), interpolation=cv2.INTER_CUBIC)  # (w, h)

    # Total padding per axis, split in two. The -0.1 / +0.1 offsets make an
    # odd total round to (n, n + 1), e.g. 59.5 -> 59 and 60.
    pad_w = (target - scaled_w) / 2
    pad_h = (target - scaled_h) / 2
    top = int(round(pad_h - 0.1))
    bottom = int(round(pad_h + 0.1))
    left = int(round(pad_w - 0.1))
    right = int(round(pad_w + 0.1))

    padded = cv2.copyMakeBorder(frame, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    normalized = ((padded / 255. - opt.mean) / opt.std).astype(np.float32)
    chw = normalized.transpose(2, 0, 1)  # HWC -> CHW
    return torch.Tensor(chw).unsqueeze(0), ratio, left, top


# Pre-process every frame once, then split the per-image results into
# parallel lists (input tensor, scale ratio, left padding, top padding).
loaded_iminfos = [prepare_img(frame) for frame in image_frames]
img_tensors, ratios, lefts, tops = [], [], [], []
for im_info in loaded_iminfos:
    img_tensors.append(im_info[0])
    ratios.append(im_info[1])
    lefts.append(im_info[2])
    tops.append(im_info[3])

# One extra batch when the image count is not a multiple of batch_size,
# e.g. 85 images with batch_size 4 -> 21 full batches + 1 batch of 1.
leftover = 1 if len(img_tensors) % opt.batch_size else 0
num_batches = len(img_tensors) // opt.batch_size + leftover

# Concatenate each slice of [1, 3, H, W] tensors along dim 0.
im_batches = []
for batch_index in range(num_batches):
    first = batch_index * opt.batch_size
    last = min(first + opt.batch_size, len(img_tensors))
    im_batches.append(torch.cat(img_tensors[first:last]))


# Run the network batch by batch. Each iteration appends one tensor to
# `output` whose rows are [image index, decoded box + score + class].
output = []
for i, batch in enumerate(im_batches):
    batch = batch.to(opt.device)
    batch_size = batch.shape[0]

    # Start of the timed region (covers forward pass and decoding).
    pre_process_time = time.time()
    with torch.no_grad():
        # The model returns a list; the last element holds the head outputs.
        mode_output = model(batch)[-1]
        # Shape annotations below are the author's, for a full batch of 4.
        hm = mode_output['hm'].sigmoid_() # [4,24,128,128]; sigmoid applied in place
        wh = mode_output['wh'] # [4,2,128,128]
        reg = mode_output['reg'] # [4,2,128,128]

        # torch.cuda.synchronize()
        # forward_time = time.time()
        # Author's note: prediction is [4,100,6]
        prediction = ctdet_decode(hm, wh, reg=reg, cat_spec_wh=False, K=opt.K) # [batch_size, K, box + score + class]
        # torch.cuda.synchronize()
        decode_time = time.time()

        # print('net: {}'.format((forward_time-pre_process_time)/batch_size))
        # print('dec: {}'.format((decode_time-forward_time)/batch_size))
        # Per-image time for forward + decode. The synchronize calls are
        # commented out, so GPU timings may not be exact.
        print('dec: {}'.format((decode_time-pre_process_time)/batch_size))
        # Flatten the batch dimension: [batch_size * K, 6] (author's note: [400,6]).
        prediction = prediction.reshape(-1, prediction.shape[2]) # N*K 6

        # Step-by-step view of how idx is built:
        # aa = torch.arange(batch_size) #[0,1,2,3]
        # bb = torch.arange(batch_size).unsqueeze(1)
        # cc = torch.arange(batch_size).unsqueeze(1).repeat(1, opt.K)  #0*100 1*100 2*100 3*100
        # Shapes along the chain: [4] -> [4,1] -> [4,100] -> [400,1]
        idx = torch.arange(batch_size).unsqueeze(1).repeat(1, opt.K).view(-1, 1)

        # x = prediction.new(prediction.shape[0], 1).fill_(i)
        # Prepend the within-batch image index as column 0.
        prediction = torch.cat((idx.float(), prediction), 1) # author's note: [400,7]

    # Turn the within-batch index into a global image index. Column 0 is the
    # image index, so all K rows of one image share the same value.
    prediction[:, 0] += i * opt.batch_size
    output.append(prediction)

# Row layout of `output`: image index, x1, y1, x2, y2, confidence, class.
# Before: a list of per-batch tensors, e.g. 21 x [400,7] + 1 x [100,7].
# After: a single [N,7] tensor, e.g. [8500,7].
# Only the "nothing to concatenate" case falls back to an empty tensor; any
# other torch.cat failure now surfaces instead of being swallowed by a bare
# except.
if output:
    output = torch.cat(output, 0)
else:
    output = torch.empty((0, 7), dtype=torch.get_default_dtype())

# Debug-only views kept from the annotated walkthrough.
a1 = torch.Tensor(ratios)  # one entry per image, e.g. [85]
a2 = output[:, 0]  # one entry per detection row, e.g. [8500]

# Per-row copies of each image's scale ratio and padding, gathered by the
# image index stored in column 0.
image_index = output[:, 0].long()
ratios_select = torch.index_select(torch.Tensor(ratios), 0, image_index)
lefts_select = torch.index_select(torch.Tensor(lefts), 0, image_index)
tops_select = torch.index_select(torch.Tensor(tops), 0, image_index)

# Feature-map coordinates -> padded network-input pixels.
output[:, 1:5] *= opt.downsample_rate

# Debug-only views kept from the annotated walkthrough.
a3 = output[:, [1, 3]]  # [N,2]
a4 = lefts_select.unsqueeze(1)  # [N,1]
# a5 = torch.Tensor([[2,3]])
# a6 = torch.Tensor([[5]])
# print(a5)
# print(a6)
# a7 = a5 - a6

# Undo the pre-processing in reverse order: remove the padding, then undo
# the resize. The [N,1] operands broadcast across the selected columns.
output[:, [1, 3]] -= lefts_select.unsqueeze(1)
output[:, [2, 4]] -= tops_select.unsqueeze(1)
output[:, 1:5] /= ratios_select.unsqueeze(1)

# (height, width) of every original frame, gathered per detection row.
frames_shape = [im.shape[:2] for im in image_frames]
frames_shape_select = torch.index_select(torch.Tensor(frames_shape), 0, image_index)

# Clamp boxes to the image bounds. The x coordinates live in columns 1 and 3
# and the y coordinates in columns 2 and 4. The previous code clamped
# columns [3, 5] and [4, 6] (x2 + confidence, y2 + class), which left x1/y1
# unclamped. Done for all rows at once instead of a per-row Python loop.
if output.shape[0] > 0:
    max_x = frames_shape_select[:, 1].unsqueeze(1)
    max_y = frames_shape_select[:, 0].unsqueeze(1)
    output[:, [1, 3]] = torch.min(output[:, [1, 3]], max_x).clamp(min=0.0)
    output[:, [2, 4]] = torch.min(output[:, [2, 4]], max_y).clamp(min=0.0)

# print(output)
def write(obj_info, loaded_imgs):
    """Draw one detection (box plus label) onto its source image.

    Args:
        obj_info: one detection row laid out as
            [image index, x1, y1, x2, y2, confidence, class].
        loaded_imgs: list of images indexed by the row's image index; the
            selected image is passed to the OpenCV drawing calls.

    Returns:
        None. Rows scoring below opt.conf_thresh are skipped.
    """
    obj_img = loaded_imgs[int(obj_info[0])]

    if obj_info[5] < opt.conf_thresh:
        return

    # obj_class = int(obj_info[1]) idx x y x y conf class
    label = '{:.2f}_{}'.format(obj_info[5], obj_info[6])

    # Convert the truncated coordinates to plain Python ints. OpenCV point
    # arguments are not reliably accepted as 0-dim tensors.
    x1, y1, x2, y2 = obj_info[1:5].int().tolist()
    x1y1 = (x1, y1)
    x2y2 = (x2, y2)
    color = (255, 0, 0)
    cv2.rectangle(obj_img, x1y1, x2y2, color, thickness=1)

    # Filled background sized to the label text, anchored at the box corner.
    t_size = cv2.getTextSize(label, fontFace=cv2.FONT_HERSHEY_PLAIN, fontScale=1, thickness=1)[0]
    x2y2 = (x1y1[0]+t_size[0]+3, x1y1[1]+t_size[1]+4)
    cv2.rectangle(obj_img, x1y1, x2y2, color, thickness=-1)

    cv2.putText(obj_img, label, org=(x1y1[0], x1y1[1]+t_size[1] + 4),
                fontFace=cv2.FONT_HERSHEY_PLAIN, fontScale=1, color=[255,255,255], thickness=1)


# Draw every detection row onto the frame it belongs to.
for detection in output:
    write(detection, image_frames)

# Save each annotated frame as mubai/det_<original file name>.
os.makedirs('mubai', exist_ok=True)
det_names = ['mubai/det_{}'.format(impath.rsplit('/', 1)[-1]) for impath in image_names]
for det_name, annotated in zip(det_names, image_frames):
    cv2.imwrite(det_name, annotated)

我数据预处理部分直接resize,然后再映射到原图,下面是我写的代码,但是好像哪里有问题啊!!!框不对,我仔细看了一下还是不知道哪里有问题。。。呃呃呃!!

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import cv2
import numpy as np
import torch
import time

image_ext = ['jpg', 'jpeg', 'png']

import sys
CENTERNET_PATH = './src/'
sys.path.insert(0, CENTERNET_PATH)
CENTERNET_PATH = './src/lib/'
sys.path.insert(0, CENTERNET_PATH)

from models.model import create_model, load_model
from models.decode import ctdet_decode


class DefaultConfig(object):
    """Settings for the single-image (plain resize) CenterNet demo.

    Creating an instance also sets PyTorch's process-wide default tensor type
    (CUDA float when a GPU is available, CPU float otherwise).
    """

    def __init__(self):
        # self.demo = '../images/17790319373_bd19b24cfc_k.jpg'
        # A single image path, or a directory of images.
        self.demo = '/data_2/2019biaozhushuju/0000-2020/20200514_refdet_small_test_data/sample'
        # Side length of the square network input.
        self.reso = 512
        # Full-precision statistics, kept for reference:
        # self.mean = np.array([0.40789654, 0.44719302, 0.47026115],
        #                 dtype=np.float32).reshape(1, 1, 3)
        # self.std = np.array([0.28863828, 0.27408164, 0.27809835],
        #                dtype=np.float32).reshape(1, 1, 3)
        # Per-channel normalisation, shaped (1, 1, 3) to broadcast over HWC.
        self.mean = np.array([0.408, 0.447, 0.47],
                             dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array([0.289, 0.274, 0.278],
                            dtype=np.float32).reshape(1, 1, 3)

        self.batch_size = 4
        self.use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_gpu else "cpu")
        # Global side effect on the default tensor type.
        if self.use_gpu:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            torch.set_default_tensor_type('torch.FloatTensor')
        # Label text per class index (24 entries, matching heads['hm']).
        self.class_name = [
            'chedeng', 'chebiao', 'chepai', 'person', 'car_light_detector',
            'pingmu', 'luntaiguige', 'qizhifu', 'abs', 'chechuanshui', 'shouxieqianzi',
            'seat', 'fuzhuzhidong', 'chemenpentu', 'taban', 'zuoyianquandai', 'xsz_rq', 'shuiyin_sj',
            'shuiyin_rq', 'chejiahao', 'wbchepai', 'xsz_fuyezhang', '3cbz', 'button']

        self.arch = 'dla_34'
        # Output channels per head.
        self.heads = {'hm': 24, 'wh': 2, 'reg': 2}
        self.head_conv = 256
        self.load_model = '/data_2/project_202009/pytorch_project/CenterNet/centernet_src_202010/CenterNet-master/myfile/model_last-3epoch-multi.pth'
        # Number of candidates decoded per image.
        self.K = 100
        # Detections scoring below this are dropped.
        self.conf_thresh = 0.3
        # Factor used to scale decoded boxes back to network-input pixels.
        self.downsample_rate = 4

    def parse(self, kwargs):
        """Update existing config attributes from the dict ``kwargs``.

        Args:
            kwargs: mapping of attribute name to new value.

        Raises:
            ValueError: if any key is not an existing attribute. Nothing is
                modified in that case.
        """
        # The original `assert ValueError` was always true, so it never
        # rejected anything; validate explicitly before applying any value.
        unknown = [k for k in kwargs if not hasattr(self, k)]
        if unknown:
            raise ValueError('unknown config option(s): {}'.format(unknown))
        for k, v in kwargs.items():
            setattr(self, k, v)

def init_model(opt):
    """Create the network described by ``opt``, load its weights and return it.

    The model is moved to opt.device and switched to eval mode.
    """
    print('Creating model...')
    net = load_model(create_model(opt.arch, opt.heads, opt.head_conv), opt.load_model)
    net = net.to(opt.device)
    net.eval()
    return net

def prepare_img(img_path,opt):
    """Read an image and turn it into a normalised network input.

    The image is resized straight to opt.reso x opt.reso (the aspect ratio
    is not preserved), normalised with opt.mean / opt.std and converted to a
    [1, 3, reso, reso] tensor on opt.device.

    Returns:
        (tensor, h, w): the input tensor and the original image height and
        width.
    """
    image = cv2.imread(img_path)
    h, w, _ = image.shape
    side = opt.reso

    resized = cv2.resize(image, (side, side))
    normalized = ((resized / 255. - opt.mean) / opt.std).astype(np.float32)
    # HWC -> CHW, then add the batch dimension.
    batch = normalized.transpose(2, 0, 1).reshape(1, 3, side, side)
    tensor = torch.from_numpy(batch).to(opt.device)
    return tensor, h, w

def process(model,images,opt,h,w,img_path_):
    """Run detection on one pre-processed image and display the boxes.

    Args:
        model: network whose last output is a dict with 'hm', 'wh', 'reg'.
        images: input tensor produced by prepare_img (batch of one).
        opt: config providing K, downsample_rate, reso, conf_thresh and
            class_name.
        h, w: height and width of the original image.
        img_path_: path of the original image, re-read for drawing.

    Returns:
        None. Returns early, without opening a window, when no detection
        reaches opt.conf_thresh.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(images)[-1]
        hm = output['hm'].sigmoid_()
        wh = output['wh']
        reg = output['reg']
        dets = ctdet_decode(hm, wh, reg=reg, cat_spec_wh=False, K=opt.K) #[1,K,6]  bboxes, scores, clses

    # Feature-map coordinates -> network-input pixels -> original pixels.
    dets[...,:4] *= opt.downsample_rate
    ratio_h = h * 1.0 / opt.reso
    ratio_w = w * 1.0 / opt.reso
    dets[..., 0] *= ratio_w #x1
    dets[..., 2] *= ratio_w #x2
    dets[..., 1] *= ratio_h #y1
    dets[..., 3] *= ratio_h #y2

    # Keep only detections at or above the confidence threshold.
    keep = dets[..., 4] >= opt.conf_thresh
    keep.squeeze_() #[K]
    dets = dets[:,keep,:]
    if 0 == dets.numel():
        return

    dets = dets.detach().cpu().numpy()
    img = cv2.imread(img_path_)
    for det in dets[0]:
        # Plain Python ints. The previous astype(np.uint8) wrapped every
        # coordinate modulo 256, so boxes beyond pixel 255 were drawn in
        # the wrong place.
        x1, y1, x2, y2 = (int(v) for v in det[:4])
        cls = int(det[5])

        x1y1 = (x1, y1)
        x2y2 = (x2, y2)
        color = (255, 0, 0)
        cv2.rectangle(img, x1y1, x2y2, color, thickness=1)

        # Filled background sized to the class-name label.
        t_size = cv2.getTextSize(opt.class_name[cls], fontFace=cv2.FONT_HERSHEY_PLAIN, fontScale=1, thickness=1)[0]
        x2y2 = (x1y1[0] + t_size[0] + 3, x1y1[1] + t_size[1] + 4)
        cv2.rectangle(img, x1y1, x2y2, color, thickness=-1)

        cv2.putText(img, opt.class_name[cls], org=(x1y1[0], x1y1[1] + t_size[1] + 4),
                    fontFace=cv2.FONT_HERSHEY_PLAIN, fontScale=1, color=[255, 255, 255], thickness=1)

    cv2.namedWindow("show",0)
    cv2.imshow("show",img)
    cv2.waitKey()


if __name__ == '__main__':
    opt = DefaultConfig()
    model = init_model(opt)

    # opt.demo is either one image or a directory. Every directory entry is
    # used, in sorted order; the extension filter is disabled.
    if not os.path.isdir(opt.demo):
        img_path = [opt.demo]
    else:
        img_path = [os.path.join(opt.demo, file_name)
                    for file_name in sorted(os.listdir(opt.demo))]

    # Detect and display one image at a time.
    for cnt, img_path_ in enumerate(img_path):
        print(cnt, img_path_)
        inp_image, h, w = prepare_img(img_path_, opt)
        process(model, inp_image, opt, h, w, img_path_)

哎,没有办法,又按照官网代码,用仿射变换(get_affine_transform + cv2.warpAffine)写了一份代码,该脚本放在CenterNet-master根目录下面运行:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import cv2
import numpy as np
import torch
import time

#import torch.backends.cudnn as cudnn

torch.backends.cudnn.enabled = True

image_ext = ['jpg', 'jpeg', 'png']

import sys
CENTERNET_PATH = './src/'
sys.path.insert(0, CENTERNET_PATH)
CENTERNET_PATH = './src/lib/'
sys.path.insert(0, CENTERNET_PATH)

from models.model import create_model, load_model
from models.decode import ctdet_decode
from utils.image import get_affine_transform
from utils.post_process import ctdet_post_process


def pre_process(opt, image, scale=1.0, meta=None):
    """Warp an image into the network input and record how to undo the warp.

    Args:
        opt: config providing input_h, input_w, mean, std and down_ratio.
            Optional ``fix_res`` (defaults to True when absent) selects the
            fixed-resolution path; the other path also needs ``opt.pad``.
        image: image array whose first two dimensions are (height, width).
        scale: factor applied to the image size before warping.
        meta: ignored; accepted only to keep the signature unchanged.

    Returns:
        (images, meta): a [1, 3, inp_height, inp_width] float tensor and a
        dict with the affine centre 'c', scale 's' and the output map size
        'out_height' / 'out_width'.
    """
    height, width = image.shape[0:2]
    new_height = int(height * scale)
    new_width = int(width * scale)
    # The original hard-coded `if 1:` and left an else branch that referenced
    # an undefined `self`. The switch now comes from opt, with the previous
    # behaviour as the default.
    if getattr(opt, 'fix_res', True):
        inp_height, inp_width = opt.input_h, opt.input_w
        c = np.array([new_width / 2., new_height / 2.], dtype=np.float32)
        s = max(height, width) * 1.0
    else:
        # Input size derived from the image size and opt.pad.
        inp_height = (new_height | opt.pad) + 1
        inp_width = (new_width | opt.pad) + 1
        c = np.array([new_width // 2, new_height // 2], dtype=np.float32)
        s = np.array([inp_width, inp_height], dtype=np.float32)

    trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
    resized_image = cv2.resize(image, (new_width, new_height))
    inp_image = cv2.warpAffine(
        resized_image, trans_input, (inp_width, inp_height),
        flags=cv2.INTER_LINEAR)
    inp_image = ((inp_image / 255. - opt.mean) / opt.std).astype(np.float32)

    # HWC -> CHW, then add the batch dimension.
    images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height, inp_width)

    images = torch.from_numpy(images)
    meta = {'c': c, 's': s,
            'out_height': inp_height // opt.down_ratio,
            'out_width': inp_width // opt.down_ratio}
    return images, meta


def process_ctdet(model,images,opt):
    """Run the network and decode its heads into opt.K candidates per image.

    Returns:
        (output, dets): the last element of the model output (a dict with
        'hm', 'wh', 'reg'; 'hm' has had sigmoid applied in place) and the
        result of ctdet_decode.
    """
    with torch.no_grad():
        output = model(images)[-1]
        heatmap = output['hm'].sigmoid_()
        size_map = output['wh']
        offset_map = output['reg']
        dets = ctdet_decode(heatmap, size_map, reg=offset_map, cat_spec_wh=False, K=opt.K)
    return output, dets

def post_process_ctdet(dets, meta, opt, scale=1.0):
    """Convert decoded detections to per-class arrays and undo ``scale``.

    Args:
        dets: detection tensor from ctdet_decode.
        meta: dict from pre_process with 'c', 's', 'out_height', 'out_width'.
        opt: config providing num_classes.
        scale: test-time scale to divide the box coordinates by.

    Returns:
        dict mapping class id (1..opt.num_classes) to a float32 array of
        shape (n, 5); the first four columns are divided by ``scale``.
    """
    dets_np = dets.detach().cpu().numpy()
    dets_np = dets_np.reshape(1, -1, dets_np.shape[2])
    per_image = ctdet_post_process(
        dets_np.copy(), [meta['c']], [meta['s']],
        meta['out_height'], meta['out_width'], opt.num_classes)

    # Only one image is processed at a time, so use the first entry.
    first_image = per_image[0]
    for cls_id in range(1, opt.num_classes + 1):
        boxes = np.array(first_image[cls_id], dtype=np.float32).reshape(-1, 5)
        boxes[:, :4] /= scale
        first_image[cls_id] = boxes
    return first_image

def merge_outputs(opt, detections):
    """Merge per-class detections and keep roughly the top scoring ones.

    Args:
        opt: config providing num_classes and max_per_image.
        detections: list of dicts mapping class id (1..num_classes) to an
            (n, 5) array whose column 4 is the score.

    Returns:
        dict mapping class id to a float32 (m, 5) array. When the total
        exceeds opt.max_per_image, only rows scoring at or above the
        max_per_image-th best score are kept (ties at the threshold are all
        kept).
    """
    class_ids = range(1, opt.num_classes + 1)
    results = {}
    for j in class_ids:
        results[j] = np.concatenate(
            [detection[j] for detection in detections], axis=0).astype(np.float32)

    scores = np.hstack([results[j][:, 4] for j in class_ids])
    if len(scores) > opt.max_per_image:
        # This is a free function: the limits come from opt. The previous
        # self.max_per_image / self.num_classes raised NameError as soon as
        # trimming was needed.
        kth = len(scores) - opt.max_per_image
        thresh = np.partition(scores, kth)[kth]
        for j in class_ids:
            keep_inds = (results[j][:, 4] >= thresh)
            results[j] = results[j][keep_inds]
    return results

def add_coco_bbox(opt, bbox, cat, image, conf=1, show_txt=True):
    """Draw one box on ``image``, optionally with a class/score label.

    Args:
        opt: config providing class_name.
        bbox: box as (x1, y1, x2, y2); converted to int32 before drawing.
        cat: zero-based class index into opt.class_name.
        image: image passed to the OpenCV drawing calls.
        conf: score shown after the class name with one decimal.
        show_txt: when False only the box outline is drawn.
    """
    box = np.array(bbox, dtype=np.int32)
    x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
    # cat = (int(cat) + 1) % 80
    cat = int(cat)

    txt = '{}{:.1f}'.format(opt.class_name[cat], conf)
    font = cv2.FONT_HERSHEY_SIMPLEX
    cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]

    cv2.rectangle(image, (x1, y1), (x2, y2), (255,0,0), 2)
    if not show_txt:
        return

    # Filled label background placed just above the box's top-left corner.
    cv2.rectangle(image,
                  (x1, y1 - cat_size[1] - 2),
                  (x1 + cat_size[0], y1 - 2), (0,0,255), -1)
    cv2.putText(image, txt, (x1, y1 - 2),
                font, 0.5, (0, 0, 0), thickness=1, lineType=cv2.LINE_AA)

def save_txt(opt,results,path_txt):
    """Append the confident detections to a text file, one per line.

    Each line is ``<class name> <score> <x1> <y1> <x2> <y2>``, with the box
    coordinates truncated to integers.

    Args:
        opt: config providing num_classes, conf_thresh and class_name.
        results: dict mapping class id (1..num_classes) to rows of
            (x1, y1, x2, y2, score).
        path_txt: output file, opened in append mode.
    """
    with open(path_txt, "a") as fw:
        for j in range(1, opt.num_classes + 1):
            for bbox_ in results[j]:
                if bbox_[4] < opt.conf_thresh:
                    continue
                bbox = np.array(bbox_[:4], dtype=np.int32)
                # Class ids are 1-based, class_name is 0-based.
                cat = int(j - 1)
                fields = [opt.class_name[cat], str(bbox_[4])]
                fields.extend(str(int(v)) for v in bbox)
                # The line terminator had been split across two source lines
                # (an unterminated string literal); restored as "\n".
                fw.write(" ".join(fields) + "\n")

def show_results(img_path_,opt, results):
    """Re-read the image, draw every confident detection and display it.

    Blocks in cv2.waitKey(0) until a key is pressed.
    """
    image = cv2.imread(img_path_)
    for cls_id in range(1, opt.num_classes + 1):
        for det in results[cls_id]:
            score = det[4]
            if score < opt.conf_thresh:
                continue
            # Class ids are 1-based; add_coco_bbox expects a 0-based index.
            add_coco_bbox(opt, det[:4], cls_id - 1, image, score)
    cv2.namedWindow("show",0)
    cv2.imshow("show",image)
    cv2.waitKey(0)



class DefaultConfig(object):
    """Settings for the affine-transform CenterNet demo.

    Creating an instance also sets PyTorch's process-wide default tensor type
    (CUDA float when a GPU is available, CPU float otherwise).
    """

    def __init__(self):
        # self.demo = '../images/17790319373_bd19b24cfc_k.jpg'
        # A single image path, or a directory of images.
        self.demo = '/data_1/test_data/images-optional'
        self.reso = 512
        # Network input size used by pre_process.
        self.input_h = 512
        self.input_w = 512
        # Truthy: display each result in a window.
        self.show_result = 1
        # Truthy: append detections to save_txt/<image name>.txt.
        self.save_txt = 0
        # Full-precision statistics, kept for reference:
        # self.mean = np.array([0.40789654, 0.44719302, 0.47026115],
        #                 dtype=np.float32).reshape(1, 1, 3)
        # self.std = np.array([0.28863828, 0.27408164, 0.27809835],
        #                dtype=np.float32).reshape(1, 1, 3)
        # Per-channel normalisation, shaped (1, 1, 3) to broadcast over HWC.
        self.mean = np.array([0.408, 0.447, 0.47],
                             dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array([0.289, 0.274, 0.278],
                            dtype=np.float32).reshape(1, 1, 3)

        self.batch_size = 4
        self.use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_gpu else "cpu")
        # Global side effect on the default tensor type.
        if self.use_gpu:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            torch.set_default_tensor_type('torch.FloatTensor')
        # Label text per class index (13 entries, matching num_classes).
        self.class_name = [
            'cheliang', 'chewei', 'chelian', 'dibiao_20', 'sanjiaojia',
            'qizhibiaozhi', 'motorbike', 'dibiao_0', 'dibiao_qd', 'xiaochebiaozhipai', 'tingchebiaozhipai',
            'fanguangbeixin', 'dibiao_10']

        self.arch = 'dla_34'

        self.num_classes = 13
        # Output channels per head; 'hm' has one channel per class.
        self.heads = {'hm':self.num_classes, 'wh': 2, 'reg': 2}
        self.head_conv = 256
        self.load_model = '/data_2/CenterNet-master/myfile/model_last.pth'
        # Number of candidates decoded per image.
        self.K = 100
        # Detections scoring below this are neither drawn nor saved.
        self.conf_thresh = 0.2
        self.downsample_rate = 4
        # Ratio between the network input size and the output map size.
        self.down_ratio = 4
        # Upper bound on detections kept per image by merge_outputs.
        self.max_per_image = 100

    def parse(self, kwargs):
        """Update existing config attributes from the dict ``kwargs``.

        Args:
            kwargs: mapping of attribute name to new value.

        Raises:
            ValueError: if any key is not an existing attribute. Nothing is
                modified in that case.
        """
        # The original `assert ValueError` was always true, so it never
        # rejected anything; validate explicitly before applying any value.
        unknown = [k for k in kwargs if not hasattr(self, k)]
        if unknown:
            raise ValueError('unknown config option(s): {}'.format(unknown))
        for k, v in kwargs.items():
            setattr(self, k, v)

def init_model(opt):
    """Build the network from ``opt``, load the checkpoint and return it.

    The returned model lives on opt.device and is in eval mode.
    """
    print('Creating model...')
    network = create_model(opt.arch, opt.heads, opt.head_conv)
    network = load_model(network, opt.load_model).to(opt.device)
    network.eval()
    #cudnn.benchmark = True
    return network


if __name__ == '__main__':
    opt = DefaultConfig()
    model = init_model(opt)

    # opt.demo is either one image or a directory; every directory entry is
    # tried, in sorted order.
    if os.path.isdir(opt.demo):
        img_path = [os.path.join(opt.demo, file_name)
                    for file_name in sorted(os.listdir(opt.demo))]
    else:
        img_path = [opt.demo]

    # Only synchronise when CUDA is in use; the call fails on CPU-only
    # machines, which opt.device otherwise supports.
    if opt.use_gpu:
        torch.cuda.synchronize()
    start = time.time()
    processed = 0
    for img_path_ in img_path:
        image = cv2.imread(img_path_)
        if image is None:
            # Not an image (or unreadable): skip instead of crashing.
            print('skip unreadable image: {}'.format(img_path_))
            continue
        images, meta = pre_process(opt, image)
        images = images.to(opt.device)
        output, dets = process_ctdet(model, images, opt)

        dets = post_process_ctdet(dets, meta, opt)
        results = merge_outputs(opt, [dets])
        if opt.show_result:
            # NOTE: this blocks on a key press, which is included in the
            # timing reported below.
            show_results(img_path_, opt, results)

        if opt.save_txt:
            path_save_dir = 'save_txt'
            os.makedirs(path_save_dir, exist_ok=True)
            # splitext handles extensions of any length; the previous [:-4]
            # slice broke on names such as "a.jpeg".
            stem = os.path.splitext(os.path.basename(img_path_))[0]
            save_txt(opt, results, os.path.join(path_save_dir, stem + ".txt"))
        processed += 1
    if opt.use_gpu:
        torch.cuda.synchronize()
    end = time.time()
    time_1 = end - start
    # Guard against an empty or fully unreadable input set; the old code
    # used the loop variable and raised NameError when there were no files.
    if processed:
        print("ave time=",time_1 * 1000.0 / processed)
    print("all=",processed)
原文地址:https://www.cnblogs.com/yanghailin/p/13792478.html