pytorch实战(二)hw2——预测收入是否高于50000,分类问题

代码和ppt:

https://github.com/Iallen520/lhy_DL_Hw

遇到的一些细节问题:

1. X_train文件不带后缀名csv,所以不是规范的csv文件,不能直接用pd.read_csv,否则发现第一行名有错误,所以用原始的方法去处理

2. 记着拆分train和test,是有必要的。

3. 数据类型转换,第一次是numpy array的转换,从str转到float,第二次是pytorch数据初始化时,注意:预测问题的y是用的二维数据,float类型。

但是分类问题,y应该是用一维数据,为torch.long类型,使用nn.CrossEntropyLoss(y_hat, y_label) !!!

提示可以用y.sqeeze() 压缩维度为1的那一维

4. 最后一层是nn.linear(k,2),映射成2维,分别代表两种类别的概率,然后用np.argmax(y, axis = 1) 提取出类别值

原文地址:https://www.cnblogs.com/qiezi-online/p/14015200.html