python 读取libsvm文件

以下介绍三种读取方式：逐行手动解析（data_generator）、按字节偏移分块加载（get_data）、一次性全部加载后再分批返回（get_data_all）。

# -*- coding:utf-8 -*-
import numpy as np
import os
from sklearn import datasets


def data_generator(input_filename, batch_size, feature_size=3):
    """
    Parse a libsvm-format text file line by line and yield dense batches.

    Each non-blank line is read as ``label idx:value idx:value ...``. The
    label is truncated to an int via ``int(float(...))`` and feature indices
    are used directly as 0-based column positions; indices outside
    ``[0, feature_size)`` are ignored.

    :param input_filename: path of the libsvm-format text file.
    :param batch_size: number of rows in each yielded batch.
    :param feature_size: number of dense feature columns (default 3).
    :return: generator of ``(labels, features)`` tuples; ``labels`` has shape
        ``(batch_size,)`` and ``features`` has shape
        ``(batch_size, feature_size)``. Rows left over after the last full
        batch are not yielded.
    """
    labels = np.zeros(batch_size)
    rets = np.zeros([batch_size, feature_size])
    i = 0
    # `with` closes the file when the generator finishes or is closed.
    with open(input_filename, "r") as f:
        for line in f:
            # split() on any whitespace tolerates trailing spaces and '\r\n'.
            data = line.split()
            if not data:
                continue  # skip blank lines instead of crashing on them
            labels[i] = int(float(data[0]))
            for fea in data[1:]:
                idx, value = fea.split(":")
                idx = int(idx)
                if 0 <= idx < feature_size:
                    rets[i, idx] = float(value)
            i += 1
            if i == batch_size:
                yield labels, rets
                # Allocate fresh buffers: re-filling the arrays that were just
                # yielded would silently overwrite batches the caller kept.
                labels = np.zeros(batch_size)
                rets = np.zeros([batch_size, feature_size])
                i = 0


def get_data(input_filename, batch_size, oneline=16294, feature_size=1947):
    """
    Stream a libsvm file in byte-range chunks and yield one batch per chunk.

    ``load_svmlight_file`` is called with ``offset``/``length`` so only about
    ``batch_size`` lines are parsed at a time; how close each batch is to
    ``batch_size`` rows depends on how accurate ``oneline`` is.

    :param input_filename: path of the libsvm-format text file.
    :param batch_size: approximate number of lines per batch.
    :param oneline: estimated number of bytes per line (default 16294).
    :param feature_size: total feature count, passed as ``n_features`` so
        every chunk produces a matrix of the same width (default 1947).
    :return: generator of ``(labels, features)`` tuples, where ``features``
        keeps only columns 0-2 of each chunk.
    """
    chunk_length = oneline * batch_size
    batch = 0
    while True:
        features, labels = datasets.load_svmlight_file(
            input_filename,
            offset=chunk_length * batch,
            length=chunk_length,
            n_features=feature_size,
        )
        if features.shape[0] == 0:
            # An empty chunk is treated as end of file. Use `return`, not
            # `raise StopIteration`: since Python 3.7 (PEP 479) a
            # StopIteration raised inside a generator becomes RuntimeError.
            return
        batch += 1
        yield labels, features[0:, 0:3]


def get_data_all(input_filename, batch_size):
    """
    Load a whole libsvm file once, then yield consecutive full batches.

    :param input_filename: path of the libsvm-format file.
    :param batch_size: number of rows per batch; must be positive.
    :return: generator of ``(labels, features)`` tuples, where ``features``
        keeps only columns 0-2. Rows left over after the last full batch
        are not yielded.
    :raises ValueError: if ``batch_size`` is not positive (raised on the
        first ``next()``).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    features, labels = datasets.load_svmlight_file(input_filename)
    n_rows = features.shape[0]
    # The upper bound includes a batch that ends exactly on the last row;
    # the previous `n_rows > end_index` test skipped it. Ending the loop
    # normally also avoids `raise StopIteration` (RuntimeError per PEP 479).
    for start_index in range(0, n_rows - batch_size + 1, batch_size):
        end_index = start_index + batch_size
        yield labels[start_index:end_index], features[start_index:end_index, 0:3]


if __name__ == "__main__":
    print("====", os.getcwd())
    filename = "/home/part-00000"
    generator = data_generator(filename, 10)
    # The next() builtin works on Python 2.6+ and 3; generator.next() is
    # Python 2 only.
    labels, features = next(generator)
    print([labels])
    print(features)

    # Iterate until the generator is exhausted instead of calling next() in
    # an endless loop, which ended with an uncaught StopIteration.
    for labels, features in get_data_all(filename, 1000):
        print('data', len(labels), features.shape)

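对于需要多次循环调用的方法，可以使用缓存。需要注意的是，缓存不能直接加在包含 yield 的生成器函数上：缓存保存的是函数的返回值，而生成器函数返回的是生成器对象，无法被缓存复用。因此应把缓存加在一次性加载数据的函数（get_data）上，再由生成器（get_data_batch）对结果分批返回。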

# -*- coding:utf-8 -*-
import numpy as np
from sklearn.externals.joblib import Memory
import os
import random
from sklearn import datasets

# On-disk joblib cache: return values of functions decorated with
# @mem.cache are stored under /tmp/mycache and reused across runs.
# NOTE(review): `sklearn.externals.joblib` (imported above) was removed in
# scikit-learn 0.23; on newer versions import Memory from `joblib` directly.
mem = Memory("/tmp/mycache")


def get_data_batch(input_filename, batch_size):
    """
    Yield consecutive full batches from the cached result of ``get_data``.

    The whole file is loaded through the memoized ``get_data``, so repeated
    passes over the same file reuse the on-disk cache; only this generator
    wrapper runs on each call.

    :param input_filename: path of the libsvm-format file.
    :param batch_size: number of rows per batch; must be positive.
    :return: generator of ``(labels, features)`` tuples with all feature
        columns. Rows left over after the last full batch are not yielded.
    :raises ValueError: if ``batch_size`` is not positive (raised on the
        first ``next()``).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    features, labels = get_data(input_filename)
    n_rows = features.shape[0]
    # The upper bound includes a batch that ends exactly on the last row;
    # the previous `n_rows > end_index` test skipped it. Ending the loop
    # normally also avoids `raise StopIteration` (RuntimeError per PEP 479).
    for start_index in range(0, n_rows - batch_size + 1, batch_size):
        end_index = start_index + batch_size
        yield labels[start_index:end_index], features[start_index:end_index]


@mem.cache
def get_data(input_filename):
    """
    Load an entire libsvm-format file, memoized on disk through ``mem``.

    The cache wraps this plain loading function rather than a generator:
    callers that need batches iterate over the returned data themselves.

    :param input_filename: path of the file to load.
    :return: the result of ``datasets.load_svmlight_file``, whose element 0
        is the feature matrix and element 1 the labels.
    """
    loaded = datasets.load_svmlight_file(input_filename)
    return loaded
原文地址:https://www.cnblogs.com/tengpan-cn/p/8417739.html