How input_fn reads data

1. Reading the data with pandas

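import pandas
import census_dataset  # census helper module from the TensorFlow models repository

# train_file and test_file are the paths to the downloaded census CSV files.
# The files have no header row, so the column names come from census_dataset._CSV_COLUMNS.
train_df = pandas.read_csv(train_file, header=None, names=census_dataset._CSV_COLUMNS)
test_df = pandas.read_csv(test_file, header=None, names=census_dataset._CSV_COLUMNS)

train_df.head()  # preview the first five rows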
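To see what header=None and names= do without downloading the census files, here is a small self-contained sketch that parses an in-memory CSV. The four column names and the three records are a hypothetical subset chosen for illustration; the real list lives in census_dataset._CSV_COLUMNS.

```python
import io

import pandas as pd

# Hypothetical subset of the census columns (the real list is census_dataset._CSV_COLUMNS).
CSV_COLUMNS = ["age", "workclass", "education", "income_bracket"]

CSV_TEXT = (
    "39,State-gov,Bachelors,<=50K\n"
    "50,Self-emp-not-inc,Bachelors,<=50K\n"
    "53,Private,11th,>50K\n"
)

# header=None: every line is a data record; names=...: supplies the column labels.
train_df = pd.read_csv(io.StringIO(CSV_TEXT), header=None, names=CSV_COLUMNS)
print(train_df.head())

# Without header=None or names, pandas consumes the first record as the header row.
wrong_df = pd.read_csv(io.StringIO(CSV_TEXT))
print(len(train_df), len(wrong_df))
```

When names= is given, pandas already behaves as if header=None were passed; spelling out header=None just makes the intent explicit.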

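2. Reading data from a pandas DataFrame

Here easy_input_function (a helper that is not defined in this post) turns the DataFrame into a tf.data.Dataset that yields (feature dict, label) batches.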

ds = easy_input_function(train_df, label_key='income_bracket', num_epochs=5, shuffle=True, batch_size=10)

# Iterating over a Dataset in a Python loop requires eager execution (the default in TF 2.x).
for feature_batch, label_batch in ds.take(1):
    print('Some feature keys:', list(feature_batch.keys())[:5])
    print()
    print('A batch of ages:', feature_batch['age'])
    print()
    print('A batch of labels:', label_batch)
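Since easy_input_function is not shown, the following is a minimal sketch of what such a helper could look like, assuming TensorFlow 2.x with eager execution. It builds the dataset with tf.data.Dataset.from_tensor_slices over a dict of columns; the sample DataFrame is made up, and the shuffle buffer of 10000 is an arbitrary choice. Unlike passing dict(df) directly, this version drops the label column from the features so the label is not fed to the model twice.

```python
import pandas as pd
import tensorflow as tf


def easy_input_function(df, label_key, num_epochs, shuffle, batch_size):
    """Hypothetical sketch: wrap a DataFrame as a dataset of (feature dict, label) batches."""
    label = df[label_key]
    features = df.drop(columns=[label_key])  # keep the label out of the feature dict
    # One element per row: ({column_name: value, ...}, label)
    ds = tf.data.Dataset.from_tensor_slices((dict(features), label))
    if shuffle:
        ds = ds.shuffle(10000)
    return ds.batch(batch_size).repeat(num_epochs)


# Made-up sample rows for illustration.
df = pd.DataFrame({
    "age": [39, 50, 38, 53],
    "education": ["Bachelors", "Bachelors", "HS-grad", "11th"],
    "income_bracket": ["<=50K", "<=50K", "<=50K", ">50K"],
})

ds = easy_input_function(df, label_key="income_bracket", num_epochs=1, shuffle=False, batch_size=2)
for feature_batch, label_batch in ds.take(1):
    print("Feature keys:", list(feature_batch.keys()))
    print("Ages:", feature_batch["age"].numpy())
    print("Labels:", label_batch.numpy())
```

With shuffle=False the rows come out in their original order, which makes the batches easy to check by eye.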

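3. Reading data directly from the file system with tf.data

Here census_dataset.input_fn builds the dataset from the CSV file itself, so the data does not have to be loaded into a DataFrame first.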

ds = census_dataset.input_fn(train_file, num_epochs=5, shuffle=True, batch_size=10)

# Take a single batch and inspect it.
for feature_batch, label_batch in ds.take(1):
    print('Feature keys:', list(feature_batch.keys())[:5])
    print()
    print('Age batch:', feature_batch['age'])
    print()
    print('Label batch:', label_batch)
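census_dataset.input_fn is not shown in this post. As a rough sketch of the same idea, assuming TensorFlow 2.x, the function below reads lines with tf.data.TextLineDataset and parses each one with tf.io.decode_csv. The three-column schema, the defaults, and the temporary sample file are hypothetical. Converting the label to a boolean with tf.equal(labels, '>50K') mirrors, as far as I know, what census_dataset.input_fn does, but check the module itself before relying on that.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical 3-column schema; the real census file has more columns.
CSV_COLUMNS = ["age", "education", "income_bracket"]
CSV_COLUMN_DEFAULTS = [[0], [""], [""]]  # int for age, strings for the rest


def input_fn(data_file, num_epochs, shuffle, batch_size):
    """Sketch: stream a header-less CSV file into (feature dict, bool label) batches."""

    def parse_csv(line):
        columns = tf.io.decode_csv(line, record_defaults=CSV_COLUMN_DEFAULTS)
        features = dict(zip(CSV_COLUMNS, columns))
        labels = features.pop("income_bracket")
        return features, tf.equal(labels, ">50K")  # True for the >50K class

    ds = tf.data.TextLineDataset(data_file)  # one string element per line
    if shuffle:
        ds = ds.shuffle(buffer_size=10000)
    ds = ds.map(parse_csv)
    # Repeat after shuffling so that separate epochs do not blend together.
    return ds.repeat(num_epochs).batch(batch_size)


# Write a tiny header-less CSV to a temporary file for the demo.
tmp_dir = tempfile.mkdtemp()
train_file = os.path.join(tmp_dir, "adult.data")
rows = ["39,Bachelors,<=50K", "50,Bachelors,<=50K", "38,HS-grad,<=50K", "53,11th,>50K"]
with open(train_file, "w") as f:
    f.write("\n".join(rows))

ds = input_fn(train_file, num_epochs=1, shuffle=False, batch_size=2)
for feature_batch, label_batch in ds.take(1):
    print("Age batch:", feature_batch["age"].numpy())
    print("Label batch:", label_batch.numpy())
```

Because parsing happens inside the pipeline, only the lines currently being batched are held in memory, which is the main advantage over loading the whole file with pandas.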

Original article: https://www.cnblogs.com/augustone/p/10505931.html