2.5数据-paddlepaddle数据集imikolov

imikolov的简化版数据集
此模块将从 http://www.fit.vutbr.cz/~imikolov/rnnlm/ 下载数据集,并将训练集和测试集解析为paddle reader creator


paddle.dataset.imikolov:https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/data/dataset_cn/imikolov_cn.html


import paddle.fluid as fluid
import numpy as np
import paddle
import paddle.dataset.imikolov as imikolov

# 取最小出现50次的约束下,的词表
word_idx=imikolov.build_dict(min_word_freq=50)
# list(word_idx.keys())[:10]:[b'the', b'<unk>', '<e>', '<s>', b'N', b'of', b'to', b'a', b'in', b'and']
'''
word_idx:
{
b'the': 0,
 b'<unk>': 1,
 '<e>': 2,
 '<s>': 3,
 b'N': 4,
 ....
 }
'''

# n (int) – 如果类型是ngram,表示滑窗大小;否则表示序列最大长度
# data_type (数据类型的成员变量(NGRAM 或 SEQ)) – 数据类型 (ngram 或 sequence)
# imikolov.py:
# class DataType(object):
#     NGRAM = 1
#     SEQ = 2
# N-gram级别
train_ngram_data=imikolov.train(word_idx=word_idx,n=5,data_type=imikolov.DataType.NGRAM)
# 句子级数据
train_seq_data=imikolov.train(word_idx=word_idx,n=5,data_type=imikolov.DataType.SEQ)
# train_ngram_data:<function paddle.dataset.imikolov.reader_creator.<locals>.reader()>
# train_seq_data:<function paddle.dataset.imikolov.reader_creator.<locals>.reader()>

BATCH_SIZE=64
train_ngram_reader=fluid.io.batch(paddle.reader.shuffle(train_ngram_data, buf_size=32),batch_size=BATCH_SIZE)

train_seq_reader=fluid.io.batch(paddle.reader.shuffle(train_seq_data, buf_size=32),batch_size=BATCH_SIZE)
for ngram_data in train_ngram_reader():
    break
for seq_data in train_seq_reader():
    break

# list(word_idx.keys())[2073]:'<unk>'
# len(ngram_data),ngram_data[:15]:64

# 滑动窗口n=5
# 5-gram的话,表示前4个词用来预测第5个词
ngram_data[:6]
'''
[(2073, 2073, 2073, 2073, 2073),
  (2073, 2073, 2073, 2073, 2073),
  (2073, 2073, 2073, 2073, 2073),
  (2073, 2073, 2073, 2073, 2073),
  (2073, 2073, 2073, 2073, 2073),
  (74, 390, 35, 2073, 0)]
 '''
# n 最大是5
# seq的数据,有2组数据,输入第一句,预测第二句?
seq_data[:6]
'''
([3, 2073, 510, 6, 842], [2073, 510, 6, 842, 2]),
  ([3, 1, 1094], [1, 1094, 2]),
  ([3, 519, 48, 33, 845], [519, 48, 33, 845, 2]),
  ([3, 461], [461, 2]),
  ([3, 36, 91, 305], [36, 91, 305, 2]),
  ([3, 12, 2073, 1431], [12, 2073, 1431, 2])]
'''
def show(idx):
    s=[str((list(word_idx.keys())[i])) for i in idx]
    return ' '.join(s)
show(ngram_data[5]) # "b'years' b'old' b'will' <unk> b'the'"
show(seq_data[5][0]) # "<s> b'that' <unk> b'attention'"

原文地址:https://www.cnblogs.com/onenoteone/p/12441675.html