# SVM视频跟踪 (SVM video tracking)

# -*- coding: utf-8 -*-
"""
Created on Thu Nov  8 21:44:12 2018

@author: xg
"""

import cv2
import numpy as np
from sklearn.svm import SVC
from skimage import measure,color
import matplotlib.pyplot as plt

# Shared GUI / model state. These live at module level because both the OpenCV
# mouse callback and the main loop read and reassign them.
font = cv2.FONT_HERSHEY_SIMPLEX
# Application mode: 0 = live video, 1 = frozen frame for interactive set-up,
# 2 = tracking.
state = 0
# Training samples gathered by clicking: colour triples and their labels.
X_list, y_list = [], []
# Endpoints of the line drawn on the set-up preview while 'steps' is 2.
lx1 = ly1 = 0
lx2 = ly2 = 0
# Opposite corners of the tracking region of interest (ROI).
rx1 = ry1 = 0
rx2 = ry2 = 1
# Cursor position used as the centre of the magnified preview.
zoomX = zoomY = 10
# 'rgb' uses raw BGR pixel values as SVM features; any other value uses HSV.
colorModel = 'rgb'

def nothing(x):
    """No-op trackbar callback; the new position ``x`` is deliberately ignored."""
    return None
def mouse_callback(event, x, y, flags, param):
    """Handle mouse events on the 'capture' window; only active in set-up mode.

    While ``state == 1`` the behaviour depends on the 'steps' trackbar:

    * steps 0/1 -- a left click appends the clicked pixel's colour to
      ``X_list`` (BGR order, or HSV when ``colorModel != 'rgb'``) with label
      -1 (step 0, red dot) or +1 (step 1, blue dot) in ``y_list``. Plain
      mouse movement re-centres the magnified preview (``zoomX``, ``zoomY``).
    * step 2 -- left button down sets ROI corner ``(rx1, ry1)``. Dragging
      with the left button held sets the opposite corner ``(rx2, ry2)``.

    Args:
        event: ``cv2.EVENT_*`` code of the event.
        x, y: cursor position in image pixels.
        flags: ``cv2.EVENT_FLAG_*`` bitmask.
        param: user data from ``cv2.setMouseCallback`` (unused).
    """
    global lx1, ly1, lx2, ly2
    global rx1, ry1, rx2, ry2
    global zoomX, zoomY
    if state != 1:
        return
    height, width = frame.shape[:2]
    step = cv2.getTrackbarPos('steps', 'capture')
    if step < 2:
        # The cursor can be reported outside the image (e.g. while a button is
        # held and it leaves the window). Clamp it to a valid pixel index.
        x = min(max(x, 0), width - 1)
        y = min(max(y, 0), height - 1)
        if event == cv2.EVENT_LBUTTONDOWN:
            # Sample the clean frame rather than showimage. showimage carries
            # the 'set up' caption and the dots of earlier clicks, so clicking
            # next to a previous dot used to record the marker colour as a
            # training sample.
            pixel = frame[y:y + 1, x:x + 1].copy()
            if colorModel == 'rgb':
                b, g, r = pixel[0, 0]
                print("clicked at:x=", x, 'y=', y, ' r=', r, 'g=', g, 'b=', b)
                X_list.append([np.float64(b), np.float64(g), np.float64(r)])
            else:
                # Convert on demand so the sample matches the frame on screen.
                # A cached HSV image would go stale after ROI drawing re-reads
                # the camera.
                hh, ss, vv = cv2.cvtColor(pixel, cv2.COLOR_BGR2HSV)[0, 0]
                print('clicked at:x=', x, 'y=', y, ' H=', hh, ' S=', ss, ' V=', vv)
                X_list.append([np.float64(hh), np.float64(ss), np.float64(vv)])
            if step == 0:
                y_list.append(-1)
                cv2.circle(showimage, (x, y), 1, (0, 0, 255), -1)
            else:
                y_list.append(1)
                cv2.circle(showimage, (x, y), 1, (255, 0, 0), -1)
        elif event == cv2.EVENT_MOUSEMOVE:
            zoomX, zoomY = x, y
    else:
        # ROI corners are slice bounds, so width/height themselves are allowed.
        x = min(max(x, 0), width)
        y = min(max(y, 0), height)
        if event == cv2.EVENT_LBUTTONDOWN:
            rx1, ry1 = x, y
        elif event == cv2.EVENT_MOUSEMOVE and flags & cv2.EVENT_FLAG_LBUTTON:
            # Bit test instead of ==, so a drag with Ctrl/Shift held still counts.
            rx2, ry2 = x, y

# Linear-kernel SVM shared by processing() (training on the clicked samples)
# and tracking() (per-pixel classification). C=0.025 is a strongly
# regularised setting: in scikit-learn's SVC a smaller C means a softer margin.
clf = SVC(C=0.025, kernel="linear")

def processing():
    """Train the global ``clf`` on the clicked samples and print training accuracy.

    Uses the module-level ``X_list`` (colour triples) and ``y_list`` (labels
    -1 / +1) collected by the mouse callback. The printed score is accuracy
    on the training samples themselves, not a held-out estimate.

    Returns:
        The training accuracy, or ``None`` when training was skipped because
        samples from both classes have not been collected yet.
    """
    X = np.asarray(X_list, dtype=np.float64)
    y = np.asarray(y_list)
    # SVC.fit raises ValueError unless at least two classes are present.
    # Pressing 'c' too early used to crash the whole interactive loop.
    if np.unique(y).size < 2:
        print('need samples from both classes before training; labels so far:',
              sorted(set(y_list)))
        return None
    clf.fit(X, y)
    score = clf.score(X, y)
    print('score=', score)
    return score
def connected_domain():
    """Label and display the 8-connected regions of the current tracking mask.

    Calls ``tracking()``, labels the returned mask with
    ``skimage.measure.label(..., connectivity=2)``, prints
    ``labels.max() + 1`` and shows the mask beside a colour-coded label image
    in a matplotlib figure (``plt.show()``).
    """
    _, mask = tracking()
    region_labels = measure.label(mask, connectivity=2)  # 8-connected labelling
    coloured = color.label2rgb(region_labels)  # one colour per label
    # Label ids start at 0, hence the +1.
    print('regions number:', region_labels.max() + 1)

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    panels = ((mask, {'cmap': plt.cm.gray}), (coloured, {}))
    for ax, (picture, extra) in zip(axes, panels):
        ax.imshow(picture, interpolation='nearest', **extra)
        ax.axis('off')

    fig.tight_layout()
    plt.show()
def tracking():
    """Classify every ROI pixel with ``clf`` and box the largest detected blob.

    Reads the module-level ``frame``, ROI corners ``rx1, ry1, rx2, ry2``,
    ``colorModel`` and classifier ``clf``.

    Returns:
        ``(roi, mask)``.

        * ``roi`` is a BGR copy of the ROI with a green bounding box drawn
          around the largest white region of ``mask``, if there is one.
        * ``mask`` is a uint8 image of the same height/width that is 0 where
          ``clf.decision_function`` is > 0 and 255 elsewhere.

        If ``clf`` has not been trained yet, the mask is all zeros and no box
        is drawn.
    """
    height, width = frame.shape[:2]
    # Clip the ROI to the frame; negative indices would otherwise wrap around.
    x0, x1 = sorted(min(max(v, 0), width) for v in (rx1, rx2))
    y0, y1 = sorted(min(max(v, 0), height) for v in (ry1, ry2))
    if x0 == x1 or y0 == y1:
        # Zero-area ROI: decision_function rejects an empty sample set, so
        # classify the whole frame instead.
        x0, y0, x1, y1 = 0, 0, width, height
    # Copy so the box drawn below does not leak into ``frame``. A second
    # tracking() call (e.g. from connected_domain()) would classify it again.
    image1 = frame[y0:y1, x0:x1].copy()
    rows, cols = image1.shape[:2]
    if not hasattr(clf, 'support_vectors_'):
        # 't' pressed before training with 'c': nothing to classify with yet.
        return image1, np.zeros((rows, cols), dtype=np.uint8)

    if colorModel == 'rgb':
        features = image1
    else:
        features = cv2.cvtColor(image1, cv2.COLOR_BGR2HSV)
    scores = clf.decision_function(features.reshape(rows * cols, 3))
    # Vectorised threshold (replaces the per-pixel Python double loop).
    image2 = np.where(scores.reshape(rows, cols) > 0, 0, 255).astype(np.uint8)

    # [-2] selects the contour list under both OpenCV 3.x (image, contours,
    # hierarchy) and OpenCV 4.x (contours, hierarchy). Pass a copy because
    # OpenCV releases before 3.2 modify the input image.
    contours = cv2.findContours(image2.copy(), cv2.RETR_LIST,
                                cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) > 0:
        # The largest region is the best candidate for the tracked object.
        # contours[0] is merely whichever one the scan emitted first.
        x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
        # Draw on the 3-channel ROI. On the single-channel mask the colour
        # (0, 255, 0) collapses to 0, giving an invisible black box that
        # erased mask pixels.
        cv2.rectangle(image1, (x, y), (x + w, y + h), (0, 255, 0), 2)
    return image1, image2

# ---------------------------------------------------------------------------
# Interactive main loop. It stays at module level because the callbacks above
# read the globals it assigns (frame, showimage, hsvimage, state).
#
# Keys: p = freeze frame for set-up, s = save the frozen frame to pic.jpg,
#       v = back to live video, c = train the SVM, t = start tracking,
#       d = connected-region view (while tracking), q = quit.
# 'steps' trackbar (set-up mode): 0 = label class -1, 1 = label class +1,
#       2 = drag to draw the ROI.
# ---------------------------------------------------------------------------
cap = cv2.VideoCapture(0)
ret, frame = cap.read()
if not ret:
    # Without a first frame every .copy() below would fail with an opaque
    # AttributeError on None.
    cap.release()
    raise SystemExit('cannot read a frame from camera 0')
image = frame.copy()
showimage = frame.copy()
showimage2 = frame.copy()

cv2.namedWindow('capture')
cv2.setMouseCallback('capture', mouse_callback)
cv2.createTrackbar('steps', 'capture', 0, 2, nothing)

try:
    while True:
        step = cv2.getTrackbarPos('steps', 'capture')
        # A fresh frame is needed in live mode, while drawing the ROI and while
        # tracking. Set-up steps 0/1 work on the frozen frame.
        if state != 1 or step > 1:
            ret, frame = cap.read()
            if not ret:
                print('camera read failed, exiting')
                break

        if state == 0:
            showimage = frame.copy()
            hsvimage = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
            cv2.putText(showimage, 'vedio', (0, 30), font, 1.2, (255, 0, 0), 2)
            showimage2 = frame.copy()
        elif state == 1 and step < 2:
            # Magnify a 20x20 window around the cursor 10x and draw a red
            # crosshair at the centre of the (nominally 200x200) preview.
            img_h, img_w = showimage.shape[:2]
            # Clamp the centre so the slice below can never be empty.
            cx = min(max(zoomX, 0), img_w - 1)
            cy = min(max(zoomY, 0), img_h - 1)
            zoomimage = showimage[max(0, cy - 10):min(cy + 10, img_h),
                                  max(0, cx - 10):min(cx + 10, img_w)]
            showimage2 = cv2.resize(zoomimage, (0, 0), fx=10, fy=10,
                                    interpolation=cv2.INTER_CUBIC)
            cv2.line(showimage2, (50, 100), (150, 100), (0, 0, 255), 1)
            cv2.line(showimage2, (100, 50), (100, 150), (0, 0, 255), 1)
        elif state == 1:
            # ROI-drawing step: live frame with the guide line and ROI outline.
            showimage = frame.copy()
            cv2.line(showimage, (lx1, ly1), (lx2, ly2), (0, 255, 0), 1)
            cv2.rectangle(showimage, (rx1, ry1), (rx2, ry2), (255, 0, 0), 2)
        elif state == 2:
            showimage, showimage2 = tracking()

        cv2.imshow("capture", showimage)
        cv2.imshow('test', showimage2)
        k = cv2.waitKey(1) & 0xFF
        if k == ord('p'):
            state = 1
            image = frame.copy()
            showimage = frame.copy()
            showimage = cv2.putText(showimage, 'set up', (0, 30), font, 1.2, (255, 0, 0), 2)
        elif k == ord('s'):
            if state == 1:
                cv2.imwrite("pic.jpg", frame)
        elif k == ord('v'):
            state = 0
        elif k == ord('c'):
            processing()
        elif k == ord('t'):
            state = 2
        elif k == ord('d') and state == 2:
            connected_domain()
        elif k == ord('q'):
            break
finally:
    # Always free the camera and the HighGUI windows, even if a handler raised.
    cap.release()
    cv2.destroyAllWindows()
# 原文地址 (original article): https://www.cnblogs.com/Manuel/p/10430105.html