数据不平衡

1、PCA降维

 pca = PCA(n_components=20)
 trans = pca.fit(train_data1)
 train_data = trans.transform(train_data1)

2、class_weight='balanced':设置该参数后,会按各类样本数的倒数自动设置类别权重(n_samples / (n_classes * 各类样本数)),使每一类样本对损失的总贡献相等

model = SVM.SVC(kernel='rbf', C=0.1, gamma=1,class_weight = 'balanced').fit(train_data, train_label)

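3、欠采样方法1:RandomUnderSampler 是一种快速且十分简单的类别平衡方法:对每个类别随机选取数据的子集。(注:imbalanced-learn 0.4 起 fit_sample 已更名为 fit_resample,较新版本中已不再提供 fit_sample,下文示例中的 fit_sample 需相应替换为 fit_resample。)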

 rus = RandomUnderSampler(random_state=0)
 train_data, train_label = rus.fit_sample(train_data1, train_label)
 print(train_data.shape)

4、欠采样方法2:原型生成(prototype generation)。给定数据集 $S$,原型生成算法会生成一个子集 $S'$,其中 $|S'| < |S|$,但 $S'$ 中的样本并非直接取自原始数据集。也就是说:原型生成方法会减少数据集的样本数量,剩下的样本是由原始数据集生成的,而不是直接从原始数据集中抽取的。ClusterCentroids 函数实现了上述功能:每一个类别的样本都用 K-Means 算法得到的聚类中心来合成,而不是从原始样本中随机抽取。

 cc = ClusterCentroids(random_state=0)
 train_data, train_label = cc.fit_sample(train_data1, train_label)
 print(train_data.shape)

5、适当减少数据量:随机选取十分之一(0.1)的训练数据。这样做主要是为了缩短 SVM 的训练时间;需要注意,减少训练数据本身并不能防止过拟合。

train_data,_ , train_label, _ = train_test_split(train_data0, train_label0, test_size=0.9)

还可以采用过采样:

除了随机过采样之外,还有两种比较流行的少数类过采样方法:(i) Synthetic Minority Oversampling Technique (SMOTE);(ii) Adaptive Synthetic (ADASYN)。

SMOTE:对于少数类样本 a,从它的 k 个少数类最近邻中随机选择一个样本 b,然后在 a 与 b 的连线上随机选取一个点 c 作为新的少数类样本;

ADASYN: 关注的是在那些基于K最近邻分类器被错误分类的原始样本附近生成新的少数类样本

from imblearn.over_sampling import SMOTE, ADASYN

X_resampled_smote, y_resampled_smote = SMOTE().fit_sample(X, y)

sorted(Counter(y_resampled_smote).items())
Out[29]:
[(0, 4674), (1, 4674), (2, 4674)]

X_resampled_adasyn, y_resampled_adasyn = ADASYN().fit_sample(X, y)

sorted(Counter(y_resampled_adasyn).items())
Out[30]:
[(0, 4674), (1, 4674), (2, 4674)]
#coding=utf-8

import numpy as np
import pickle
from sklearn.metrics import classification_report
import sklearn.svm as SVM
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from imblearn.under_sampling import ClusterCentroids
from imblearn.under_sampling import RandomUnderSampler
def load_data(fname):
    """Load a dataset from a pickle file and return the unpickled object.

    The file at *fname* is opened in binary mode and closed again before
    the function returns.

    NOTE: ``pickle.load`` can execute arbitrary code while deserializing;
    only call this on files you trust.
    """
    with open(fname, 'rb') as handle:
        return pickle.load(handle)

def _flatten(samples):
    """Reshape an array of shape (n, d1, d2, ...) into a 2-D (n, d1*d2*...) matrix."""
    return samples.reshape((samples.shape[0], -1))


def _resample(sampler, features, labels):
    """Run an imbalanced-learn sampler and return ``(features, labels)``.

    ``fit_sample`` was renamed ``fit_resample`` in imbalanced-learn 0.4 and
    the old name is no longer available in current releases; prefer the new
    name and fall back to the old one for very old installations.
    """
    fit = getattr(sampler, 'fit_resample', None) or sampler.fit_sample
    return fit(features, labels)


def main(train_path='/home/hd_1T/haiou/class/machinelearning/data/data1/train_data.pkl',
         test_path='/home/hd_1T/haiou/class/machinelearning/data/data1/test_data.pkl',
         sample_ratio=0.1, n_components=None, sampler=None, random_state=None):
    """Train an RBF SVM on a class-balanced subset of the training data and
    print classification reports for the training and test sets.

    Parameters
    ----------
    train_path, test_path : str
        Pickle files that each hold a ``(data, labels)`` pair. Presumably
        numpy arrays (``.shape``/``.reshape`` are used); every sample is
        flattened to a 1-D feature vector.
    sample_ratio : float
        Fraction of the training set kept before balancing (0.1 keeps one
        tenth); the rest is discarded to shorten SVM training time.
    n_components : int or None
        If given, project the flattened features onto this many PCA
        components. PCA is fitted on the training data only and the same
        projection is reused on the test data. ``None`` disables PCA.
    sampler : imbalanced-learn sampler or None
        Undersampler used to balance the classes; defaults to
        ``ClusterCentroids(random_state=0)``.
    random_state : int or None
        Seed for the subsampling split; ``None`` draws a different subset on
        every run.
    """
    # Load training data (pickle: only load trusted files).
    train_data0, train_label0 = load_data(train_path)
    # Keep only `sample_ratio` of the training samples; the remainder is dropped.
    train_data, _, train_label, _ = train_test_split(
        train_data0, train_label0, train_size=sample_ratio, random_state=random_state)
    print(train_data.shape)
    print(train_label.shape)
    train_features = _flatten(train_data)

    # Optional dimensionality reduction. It must run BEFORE resampling so the
    # sampler and the model both see the reduced features.
    pca = None
    if n_components is not None:
        pca = PCA(n_components=n_components).fit(train_features)
        train_features = pca.transform(train_features)

    # Balance the classes by undersampling.
    if sampler is None:
        sampler = ClusterCentroids(random_state=0)
    train_data, train_label = _resample(sampler, train_features, train_label)
    print(train_data.shape)

    # Build and train the model.
    model = SVM.SVC(kernel='rbf', C=0.1, gamma=1, class_weight='balanced').fit(train_data, train_label)
    pred = model.predict(train_data)
    print('report on training data:')
    print(classification_report(train_label, pred))

    # Load test data and apply exactly the same feature pipeline as training.
    test_data, test_label = load_data(test_path)
    print(test_data.shape)
    print(test_label.shape)
    test_features = _flatten(test_data)
    if pca is not None:
        test_features = pca.transform(test_features)

    # Predict on the test set and compare with the ground-truth labels.
    pred = model.predict(test_features)
    print('report on test data:')
    print(classification_report(test_label, pred))

# Run the training/evaluation pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()

参考:

https://blog.csdn.net/kizgel/article/details/78553009?locationNum=6&fps=1#213-smote的变体

原文地址:https://www.cnblogs.com/hozhangel/p/11101617.html