Getting the outputs of all hidden layers in BERT

https://github.com/huggingface/transformers/issues/1827

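from transformers import BertModel, BertConfig

# "xxx" stands for a model name or a local checkpoint path
config = BertConfig.from_pretrained("xxx", output_hidden_states=True)
model = BertModel.from_pretrained("xxx", config=config)

# inputs: tensor of token ids with shape (batch_size, sequence_length)
outputs = model(inputs)
print(len(outputs))  # 3

# Tuple with one tensor per layer, plus the embedding output
hidden_states = outputs[2]
print(len(hidden_states))  # 13

embedding_output = hidden_states[0]  # output of the embedding layer
attention_hidden_states = hidden_states[1:]  # outputs of the 12 encoder layers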

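The BERT model returns the tuple (last_hidden_state, pooler_output, hidden_states[optional], attentions[optional]).

outputs[0] is therefore the last hidden state and outputs[1] is the pooler output. With output_hidden_states=True and attentions not requested, the hidden states come third, at outputs[2].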
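
The layout of the returned values can be checked without downloading a checkpoint. The following is a minimal sketch, not the code from the linked issue. It assumes a recent transformers release (v4 or later, where the model returns an output object with named fields) and PyTorch. The tiny configuration sizes and the token ids are made up for illustration, and the weights are random.

```python
import torch
from transformers import BertConfig, BertModel

# Tiny, randomly initialised model (hypothetical sizes) so nothing is downloaded.
# A real checkpoint would be loaded with from_pretrained instead.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    output_hidden_states=True,
)
model = BertModel(config)
model.eval()  # disable dropout so the forward pass is deterministic

# Made-up token ids: a batch of 1 sequence with 5 tokens
input_ids = torch.tensor([[1, 2, 3, 4, 5]])

with torch.no_grad():
    outputs = model(input_ids)

# One tensor per encoder layer, plus the embedding output at index 0
hidden_states = outputs.hidden_states
num_states = len(hidden_states)
embedding_shape = tuple(hidden_states[0].shape)

# The last entry of hidden_states is the same tensor as the last hidden state
last_matches = torch.equal(hidden_states[-1], outputs[0])
```

With 2 encoder layers, hidden_states holds 3 tensors, each of shape (batch_size, sequence_length, hidden_size). The 13 seen above follows the same rule for a 12-layer model.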

Original post: https://www.cnblogs.com/BlueBlueSea/p/13736223.html