Data Processing (Super-Resolution)

1. Cropping 24×24 and 96×96 patches

import os
import random
import cv2

def crop_mod(img, scale):
    # Crop the image so that its height and width are exact multiples of `scale`.
    h, w = img.shape[0], img.shape[1]
    a = h % scale
    b = w % scale
    if len(img.shape) == 2:
        img = img[0:h-a, 0:w-b]
    else:
        img = img[0:h-a, 0:w-b, :]
    return img

    

nir_path = r'C:\Users\13141\Desktop\low_nir_dataset\nir'
rgb_path = r'C:\Users\13141\Desktop\low_nir_dataset\rgb'
lr_path = r'C:\Users\13141\Desktop\low_nir_dataset\x4\nir_lr'
bic_path = r'C:\Users\13141\Desktop\low_nir_dataset\x4\nir_bic'

image_list = os.listdir(nir_path)
image_list.sort()

for i in range(len(image_list)):
    count = 1
    nir_img_path = os.path.join(nir_path, image_list[i])
    bic_img_path = os.path.join(bic_path, image_list[i])
    lr_img_path = os.path.join(lr_path, image_list[i])
    rgb_img_path = os.path.join(rgb_path, image_list[i].replace('nir','rgb'))
    
    nir_img = cv2.imread(nir_img_path,cv2.IMREAD_GRAYSCALE)
    bic_img = cv2.imread(bic_img_path,cv2.IMREAD_GRAYSCALE)
    lr_img = cv2.imread(lr_img_path,cv2.IMREAD_GRAYSCALE)
    rgb_img = cv2.imread(rgb_img_path)

    # Crop the HR images to a multiple of the scale factor so they align with the x4 LR grid.
    nir_img = crop_mod(nir_img, 4)
    rgb_img = crop_mod(rgb_img, 4)
    
    h, w = nir_img.shape[0], nir_img.shape[1]
    # Slide a 96x96 window over the HR image with a stride of 48 (50% overlap).
    x_min = list(range(0, (h-96), 48))
    y_min = list(range(0, (w-96), 48))

    for xx in x_min:
        for yy in y_min:
            nir_pat = nir_img[xx:(xx+96), yy:(yy+96)]
            bic_pat = bic_img[xx:(xx+96), yy:(yy+96)]
            rgb_pat = rgb_img[xx:(xx+96), yy:(yy+96), :]
            # The matching 24x24 LR patch starts at the same coordinates divided by the scale factor 4.
            xx_l = xx // 4
            yy_l = yy // 4
            lr_pat = lr_img[xx_l:(xx_l+24), yy_l:(yy_l+24)]

            # Zero-pad the running patch index to five digits for the file name.
            count_0 = 5 - len(str(count))
            image_name = image_list[i].split('.')[0] + '_' + '0'*count_0 + str(count) + '.tiff'
            print(image_name)
            cv2.imwrite(os.path.join(r'C:\Users\13141\Desktop\low_nir_dataset\x4\patches_24\nir',
                                     image_name), nir_pat)
            cv2.imwrite(os.path.join(r'C:\Users\13141\Desktop\low_nir_dataset\x4\patches_24\nir_bic',
                                     image_name), bic_pat)
            cv2.imwrite(os.path.join(r'C:\Users\13141\Desktop\low_nir_dataset\x4\patches_24\nir_lr',
                                     image_name), lr_pat)
            cv2.imwrite(os.path.join(r'C:\Users\13141\Desktop\low_nir_dataset\x4\patches_24\rgb',
                                     image_name), rgb_pat)
            count+=1

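As a quick sanity check on the patch extraction, one can downsample a 96×96 HR patch by 4× and compare it with the stored 24×24 LR patch at the same index. This is only a minimal sketch: it assumes the patch directories written above, and since the kernel used to produce the LR images is not specified here, bicubic resizing is used and the resulting difference should only be treated as a rough consistency check.

import os
import cv2
import numpy as np

hr_dir = r'C:\Users\13141\Desktop\low_nir_dataset\x4\patches_24\nir'
lr_dir = r'C:\Users\13141\Desktop\low_nir_dataset\x4\patches_24\nir_lr'

name = sorted(os.listdir(hr_dir))[0]
hr = cv2.imread(os.path.join(hr_dir, name), cv2.IMREAD_GRAYSCALE)
lr = cv2.imread(os.path.join(lr_dir, name), cv2.IMREAD_GRAYSCALE)

# Downsample the HR patch by 4x (bicubic is an assumption) and compare with the LR patch.
hr_down = cv2.resize(hr, (24, 24), interpolation=cv2.INTER_CUBIC)
print(hr.shape, lr.shape)  # expected: (96, 96) (24, 24)
print(np.abs(hr_down.astype(np.float32) - lr.astype(np.float32)).mean())
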
2. Edge extraction with the Canny operator

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
path = r'C:\Users\13141\Desktop\low_nir_dataset\x4\patches_24\nir'

img_list = os.listdir(path)

count = 0
img_del = []
for i in range(len(img_list)):
    img_path = os.path.join(path, img_list[i])
    img = cv2.imread(img_path)
    # Smooth with a 3x3 Gaussian kernel to suppress noise before edge detection.
    img = cv2.GaussianBlur(img, (3, 3), 0)
    canny1 = cv2.Canny(img, 0, 55)
    l = canny1.sum()
#   Record patches whose Canny edge map is all zeros (i.e. no texture).
    if l == 0:
        print(img_list[i])
        img_del.append(img_list[i])
        count +=1
#     plt.imshow(canny1, cmap='gray')
#     plt.show()

Delete the zero-edge images

path = r'C:\Users\13141\Desktop\low_nir_dataset\x4\patches_24\nir'
bic_path = r'C:\Users\13141\Desktop\low_nir_dataset\x4\patches_24\nir_bic'
lr_path = r'C:\Users\13141\Desktop\low_nir_dataset\x4\patches_24\nir_lr'
rgb_path = r'C:\Users\13141\Desktop\low_nir_dataset\x4\patches_24\rgb'
for i in img_del:
    os.remove(os.path.join(path, i))
    os.remove(os.path.join(bic_path, i))
    os.remove(os.path.join(lr_path, i))
    os.remove(os.path.join(rgb_path, i))

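After deleting, it is worth confirming that the four patch directories still contain exactly the same file names. A minimal sketch, reusing the path variables defined just above:

dirs = [path, bic_path, lr_path, rgb_path]
name_sets = [set(os.listdir(d)) for d in dirs]
# All four folders should hold identical file names after the deletion.
assert all(s == name_sets[0] for s in name_sets), 'patch folders are out of sync'
print(len(name_sets[0]), 'patches remain in each folder')
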
3. Image registration

import numpy as np
import cv2

def sift_kp(image):
    # Convert to grayscale if the input has a colour channel
    # (checking len(image.shape) avoids an IndexError on 2-D grayscale inputs).
    if len(image.shape) == 3:
        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray_image = image
    # For OpenCV >= 4.4 use cv2.SIFT_create() instead of cv2.xfeatures2d.SIFT_create().
    sift = cv2.xfeatures2d.SIFT_create()
    kp, des = sift.detectAndCompute(gray_image, None)
    kp_image = cv2.drawKeypoints(gray_image, kp, None)
    return kp_image, kp, des

def get_good_match(des1, des2):
    # Brute-force kNN matching followed by Lowe's ratio test (ratio 0.75).
    bf = cv2.BFMatcher()
    matches = bf.knnMatch(des1, des2, k=2)
    good = []
    for m, n in matches:
        if m.distance < 0.75 * n.distance:
            good.append(m)
    return good

def siftImageAlignment(img1, img2):
    _, kp1, des1 = sift_kp(img1)
    _, kp2, des2 = sift_kp(img2)
    goodMatch = get_good_match(des1, des2)
    if len(goodMatch) <= 4:
        # Without enough matches, imgOut would otherwise be undefined.
        raise ValueError('not enough good matches to estimate a homography')
    ptsA = np.float32([kp1[m.queryIdx].pt for m in goodMatch]).reshape(-1, 1, 2)
    ptsB = np.float32([kp2[m.trainIdx].pt for m in goodMatch]).reshape(-1, 1, 2)
    ransacReprojThreshold = 4
    # Estimate the homography with RANSAC and warp img2 into img1's coordinate frame.
    H, status = cv2.findHomography(ptsA, ptsB, cv2.RANSAC, ransacReprojThreshold)
    imgOut = cv2.warpPerspective(img2, H, (img1.shape[1], img1.shape[0]),
                                 flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)
    return imgOut, H, status

img1 = cv2.imread(r'E:\datasets\multispectral\data\RGB_NIR\stereo\4\img2.png')
img2 = cv2.imread(r'E:\datasets\multispectral\data\RGB_NIR\stereo\4\img1.png')
while img1.shape[0]>1000 or img1.shape[1]>1000:
    img1 = cv2.resize(img1, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_AREA)
while img2.shape[0]>1000 or img2.shape[1]>1000:
    img2 = cv2.resize(img2, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_AREA)

print(img1.shape, img2.shape)
result, _, _ = siftImageAlignment(img1, img2)
allImage = np.concatenate((img1, img2, result), axis=1)

import matplotlib.pyplot as plt
# OpenCV loads images as BGR; convert to RGB so matplotlib shows the colours correctly.
plt.imshow(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
plt.show()
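
To judge registration quality visually, the matches that passed the ratio test can also be drawn side by side. A minimal sketch reusing sift_kp and get_good_match from above (flags=2 simply hides unmatched keypoints):

_, kp1, des1 = sift_kp(img1)
_, kp2, des2 = sift_kp(img2)
good = get_good_match(des1, des2)
# Draw only the matches kept by the ratio test.
match_vis = cv2.drawMatches(img1, kp1, img2, kp2, good, None, flags=2)
plt.imshow(cv2.cvtColor(match_vis, cv2.COLOR_BGR2RGB))
plt.show()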

4. Data augmentation (PyTorch)

import random

def augment(*args, hflip=True, rot=True):
    # Draw the random flags once so that every input (e.g. the LR/HR pair)
    # receives exactly the same transformation and stays spatially aligned.
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        # Expects C x H x W numpy arrays: axis 1 is height, axis 2 is width.
        if hflip: img = img[:, ::-1, :].copy()
        if vflip: img = img[:, :, ::-1].copy()
        if rot90: img = img.transpose(0, 2, 1).copy()
        return img

    return [_augment(a) for a in args]
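
A small usage example, assuming the LR/HR patches have already been converted to C×H×W numpy arrays (the shapes below are just illustrative):

import numpy as np

lr = np.random.rand(1, 24, 24).astype(np.float32)   # 1-channel NIR LR patch
hr = np.random.rand(1, 96, 96).astype(np.float32)   # matching HR patch
lr_aug, hr_aug = augment(lr, hr)
# Both patches receive the same flip/rotation, so they remain aligned.
print(lr_aug.shape, hr_aug.shape)                    # (1, 24, 24) (1, 96, 96)
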
Original article: https://www.cnblogs.com/btschang/p/11445076.html