5.matplotlib绘制-meshgrid区域图-可视化ML

传入参数:

1)plt:调用对象中的pyplot。

如:import matplotlib.pyplot as plt

2)predict:调用对象中ML算法的predict函数,用于预测对X,Y构造网格后的预测。

如:clf = neighbors.KNeighborsClassifier(n_neighbors=15, weights='distance')

3)X, Y:绘图的参数,shape:nx1,n,1

4) classes_color:颜色列表,

 如:classes_color=['#FFAAAA', '#AAFFAA', '#AAAAFF','#00000']

5)step=0.05 ,网格细分长度

import numpy as np
from matplotlib.colors import ListedColormap

cmpcolor = ['#FFAAAA', '#AAFFAA', '#AAAAFF']


def create_meshgrid_pic(plt, predict, X, Y, classes_color=cmpcolor, step=0.05):
    # 确认训练集的边界
    x_min, x_max = X[:].min() - 1, X[:].max() + 1
    y_min, y_max = Y[:].min() - 1, Y[:].max() + 1
    # 生成网格数据,xx:所有网格点的x坐标,形状也是网格性nxm。yy同样
    xx, yy = np.meshgrid(np.arange(x_min, x_max, step),
                         np.arange(y_min, y_max, step))
    # xx,yy的扁平化成一串坐标点(密密麻麻的网格点平摊开来)
    d = np.c_[xx.ravel(), yy.ravel()]
    # 对网格点进行类型预测
    Z = predict(d)
    # 预测类型后,重新变回网格的样子,因为后面pcolormesh接收网格形式的绘图数据
    Z = Z.reshape(xx.shape)
    # 获取类型数量
    class_size = np.unique(Z).size
    if class_size > len(classes_color):
        print('颜色列表太少')
        return AttributeError
    classes_color = classes_color[:class_size]

    cmap_light = ListedColormap(classes_color)

    # 接收网格化的x,y,z
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

使用:

import matplotlib.pyplot as plt
from sklearn import neighbors
from **** import create_meshgrid_pic

X, Y = ()
clf = neighbors.KNeighborsClassifier(n_neighbors=15, weights='distance')
clf.fit(X, Y)
cmap_light = ['#FFAAAA', '#AAFFAA', '#AF0000']
create_meshgrid_pic(plt, clf, X[:, 0], X[:, 1], cmap_light, 0.02)
plt.show()

原文地址:https://www.cnblogs.com/onenoteone/p/12441726.html