Python基于皮尔逊相关系数实现股票走势相似片段搜索（用于股票预测）

 1 # -*- coding: utf-8 -*-
 2 """
 3 Created on Mon Dec  2 14:49:59 2018
 4 
 5 @author: zhen
 6 """
 7 
 8 import matplotlib.pyplot as plt
 9 import numpy as np
10 import pandas as pd
11 from datetime import datetime
12 
def normal(a):
    """Min-max scale *a* into roughly [0, 1].

    The value range is padded with a tiny epsilon (0.000001) so that a
    constant input maps to all zeros instead of dividing by zero; as a side
    effect the maximum maps to slightly below 1.
    """
    lowest = np.min(a)
    spread = np.max(a) - lowest
    return (a - lowest) / (spread + 0.000001)
15 
def normalization(x):
    """Return the z-score standardisation of *x*: (x - mean) / std.

    ``np.std`` is the population standard deviation (ddof=0).  When the
    standard deviation is exactly zero (e.g. an all-equal integer series),
    dividing would give 0/0, i.e. an all-NaN result plus a RuntimeWarning;
    such input is instead mapped to all zeros, since every point equals the
    mean.

    Parameters
    ----------
    x : array_like
        Values to standardise.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as *x*.
    """
    x = np.asarray(x)
    mean = np.mean(x)
    std = np.std(x)  # population standard deviation (ddof=0)
    if std == 0:
        # Zero spread: every point sits on the mean, so its z-score is 0.
        return np.zeros(x.shape)
    return (x - mean) / std
18 
def corrcoef(a, b, scale=16, power=1):
    """Turn the Pearson correlation of *a* and *b* into a distance.

    Parameters
    ----------
    a, b : array_like
        The two series to compare; presumably equal-length 1-D sequences,
        as used in this script (``np.corrcoef(a, b)[0, 1]`` is taken).
    scale : float, optional
        Multiplier applied to the ratio (default 16, the value this script
        has always used).
    power : float, optional
        Exponent applied to the ratio before scaling (default 1, i.e. the
        ratio is used unchanged).

    Returns
    -------
    float
        ``scale * ((1 - r) / (1 + r)) ** power``, where ``r`` is the Pearson
        correlation coefficient, which lies in [-1, 1].  With the defaults
        the distance is 0 for r = 1 and 16 for r = 0, and it grows without
        bound as r approaches -1 (inf at r = -1).  A smaller value therefore
        means a more similar shape.  If either series is constant, r is
        undefined and the result is nan.
    """
    # Perfectly anti-correlated or constant windows are expected inputs here;
    # let them yield inf / nan quietly instead of emitting RuntimeWarnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        r = np.corrcoef(a, b)[0, 1]
        return scale * ((1 - r) / (1 + r)) ** power
23   
startTimeStamp = datetime.now()  # start of the run-time measurement

# Load the data.
# NOTE(review): machine-specific absolute path; change it before running elsewhere.
filename = 'C:/Users/zhen/.spyder-py3/sh000300_2017.csv'
# Read columns 0, 1 and 3 in a single pass, all as strings, so the label
# columns keep their exact text (e.g. leading zeros); column 3 is then
# converted to numbers.  Judging by how they are used below, columns 0 and 1
# are the date and time labels and column 3 is the series being matched.
frame = pd.read_csv(filename, usecols=[0, 1, 3], dtype=str)
data = np.array(frame.iloc[:, 0])    # date labels
times = np.array(frame.iloc[:, 1])   # time labels
data_points = np.array(pd.to_numeric(frame.iloc[:, 2]))  # series to match

topk = 10        # number of best matches to list
baselen = 100    # window length in points (original note: one day is 240 points)
basebegin = 361  # start index of the reference (base) window
length = len(data_points)  # total number of points

if basebegin < 0 or basebegin + baselen > length:
    raise ValueError('base window [%d, %d) is outside the %d data points'
                     % (basebegin, basebegin + baselen, length))


def _window_label(start, size):
    """Return the 'date time~date time' label of data_points[start:start+size]."""
    last = start + size - 1
    return data[start] + ' ' + times[start] + '~' + data[last] + ' ' + times[last]


basedata = _window_label(basebegin, baselen)
base = data_points[basebegin:basebegin + baselen]

# Split into candidate windows: every full window of baselen points that does
# not overlap the base window.  A window starting at j covers
# [j, j + baselen); it is disjoint from the base iff it ends at or before
# basebegin or starts at or after basebegin + baselen.  (Strict comparisons
# here would skip the window ending right before the base and the last
# full window of the series.)
subseries = []
dateseries = []  # start index of each candidate window
for j in range(length - baselen + 1):
    if j + baselen <= basebegin or j >= basebegin + baselen:
        subseries.append(data_points[j:j + baselen])
        dateseries.append(j)
if not subseries:
    raise ValueError('no candidate window of length %d outside the base window'
                     % baselen)

# Similarity search: a smaller distance means a more similar shape.
listdistance = [corrcoef(base, window) for window in subseries]

# Sort ascending; argsort places NaN distances (e.g. constant windows) last.
index = np.argsort(listdistance, kind='quicksort')

# List the top-k matches.
print('top-%d matches for %s:' % (topk, basedata))
for rank, num in enumerate(index[:topk], start=1):
    print(rank, _window_label(dateseries[num], baselen), listdistance[num])

# Show the data to be matched.
plt.figure(0)
plt.plot(base, label=basedata, linewidth=2)
plt.legend(loc='upper left')
plt.title('Base data')

# Raw values of the best match.
num = index[0]
label = _window_label(dateseries[num], len(subseries[num]))
plt.figure(1)
plt.plot(subseries[num], label=label, linewidth=2)
plt.legend(loc='upper left')
plt.title('Similarity data')

# Compare the base and the best match after z-score normalisation.
plt.figure(2)
plt.plot(normalization(base), label=basedata, linewidth=2)
plt.plot(normalization(subseries[num]), label=label, linewidth=3)
plt.legend(loc='lower right')
plt.title('normal similarity search')
plt.show()

# Elapsed wall-clock time.  It is measured after plt.show(), which may block
# while the plot windows are open, so that time may be included.
endTimeStamp = datetime.now()
print('run time', (endTimeStamp - startTimeStamp).total_seconds(), "s")

结果:（运行结果图未随文本保留，请参见下方原文地址）

原文地址:https://www.cnblogs.com/yszd/p/10058475.html