基于tensor2tensor的注意力可视化

根据训练好的 Transformer 模型得到注意力矩阵，并对注意力进行可视化。

首先安装:tensorflow 1.13.1 + tensor2tensor 1.13.1

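可视化部分请在 Jupyter Notebook 中运行（`call_html()` 依赖经典 Notebook 提供的 `/static/components/requirejs/require.js`，在 JupyterLab 中可能无法正常显示）。该代码根据 tensor2tensor/tensor2tensor/visualization/visualization.py 修改得到。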

# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared code for visualizing transformer attentions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np

# To register the hparams set
from tensor2tensor import models  # pylint: disable=unused-import
from tensor2tensor import problems
from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_lib

import tensorflow.compat.v1 as tf
from tensor2tensor.utils import usr_dir
# Token id appended to every encoded input by `AttentionVisualizer2.encode`.
# Presumably mirrors tensor2tensor's `text_encoder.EOS_ID` (ids 0/1 are
# reserved for PAD/EOS there) -- confirm if the problem uses a custom encoder.
EOS_ID = 1

class AttentionVisualizer2(object):
  """Helper object for creating Attention visualizations.

  On construction it builds the model graph (via `build_model`) and loads the
  problem's feature encoders, keeping handles to the input/target
  placeholders, the decoded-sample tensor and the attention tensors.
  """

  def __init__(
      self, hparams_set, hparams, t2t_usr_dir, model_name, data_dir,
      problem_name, beam_size=1):
    """Builds the graph and loads the feature encoders.

    Args:
      hparams_set: Name of the registered HParams set.
      hparams: Comma-separated string of hparams overrides.
      t2t_usr_dir: Directory of user-defined T2T code to import.
      model_name: Registered name of the model.
      data_dir: Path to the problem's data directory; also used to load the
        feature encoders.
      problem_name: Registered name of the problem.
      beam_size: Beam size for decoding; 1 (default) means greedy decoding.
    """
    inputs, targets, samples, att_mats = build_model(
        hparams_set, hparams, t2t_usr_dir, model_name, data_dir, problem_name,
        beam_size=beam_size)

    # Fetch the problem and its string <-> id encoders.
    ende_problem = problems.problem(problem_name)
    encoders = ende_problem.feature_encoders(data_dir)

    self.inputs = inputs
    self.targets = targets
    self.att_mats = att_mats
    self.samples = samples
    self.encoders = encoders

  @staticmethod
  def _flatten_ids(integers):
    """Returns an array-like of token ids as a flat list of Python ints.

    `list(np.squeeze(x))` raises "iteration over a 0-d array" when `x` holds
    a single id (e.g. a one-token input or a decode that is only EOS),
    because squeeze collapses it to a scalar. Reshaping to 1-D always keeps
    one dimension; with the batch size of 1 used by this class the result is
    otherwise the same id sequence.
    """
    return np.asarray(integers).reshape(-1).tolist()

  def encode(self, input_str):
    """Encodes `input_str` into an id array ready for inference, EOS appended."""
    inputs = self.encoders["inputs"].encode(input_str) + [EOS_ID]
    # Shape [1, length, 1, 1] (4-D), matching the `inputs` placeholder.
    batch_inputs = np.reshape(inputs, [1, -1, 1, 1])
    return batch_inputs

  def decode(self, integers):
    """Decodes target ids into a single string."""
    return self.encoders["targets"].decode(self._flatten_ids(integers))

  def encode_list(self, integers):
    """Decodes input ids into a list of token strings (`inputs` encoder)."""
    return self.encoders["inputs"].decode_list(self._flatten_ids(integers))

  def decode_list(self, integers):
    """Decodes target ids into a list of token strings (`targets` encoder)."""
    return self.encoders["targets"].decode_list(self._flatten_ids(integers))

  def get_vis_data_from_string(self, sess, input_string):
    """Constructs the data needed for visualizing attentions.

    Args:
      sess: A tf.Session object.
      input_string: The input sentence to be translated and visualized.
    Returns:
      Tuple of (
          output_string: The translated sentence.
          input_list: Tokenized input sentence.
          output_list: Tokenized translation.
          att_mats: Tuple of attention matrices; (
              enc_atts: Encoder self attention weights.
                A list of `num_layers` numpy arrays of size
                (batch_size, num_heads, inp_len, inp_len)
              dec_atts: Decoder self attention weights.
                A list of `num_layers` numpy arrays of size
                (batch_size, num_heads, out_len, out_len)
              encdec_atts: Encoder-Decoder attention weights.
                A list of `num_layers` numpy arrays of size
                (batch_size, num_heads, out_len, inp_len)
          )
    """
    encoded_inputs = self.encode(input_string)

    # Run inference graph to get the translation.
    out = sess.run(self.samples, {
        self.inputs: encoded_inputs,
    })

    # Run the decoded translation through the training graph to get the
    # attention tensors.
    att_mats = sess.run(self.att_mats, {
        self.inputs: encoded_inputs,
        self.targets: np.reshape(out, [1, -1, 1, 1]),
    })

    output_string = self.decode(out)
    input_list = self.encode_list(encoded_inputs)
    output_list = self.decode_list(out)

    return output_string, input_list, output_list, att_mats


def build_model(hparams_set, hparams, t2t_usr_dir, model_name, data_dir,
                problem_name, beam_size=1):
  """Build the graph required to fetch the attention weights.

  Args:
    hparams_set: HParams set to build the model with.
    hparams: Comma-separated string of hparams overrides applied on top of
      `hparams_set` (forwarded to `trainer_lib.create_hparams`).
    t2t_usr_dir: Directory of user-defined T2T code. It is imported before the
      registry lookups below so user-defined models / hparams sets can be
      found.
    model_name: Name of model.
    data_dir: Path to directory containing training data.
    problem_name: Name of problem.
    beam_size: (Optional) Number of beams to use when decoding a translation.
        If set to 1 (default) then greedy decoding is used.
  Returns:
    Tuple of (
        inputs: Input placeholder to feed in ids to be translated.
        targets: Targets placeholder to feed to translation when fetching
            attention weights.
        samples: Tensor representing the ids of the translation.
        att_mats: Tensors representing the attention weights.
    )
  """
  # Echo which model is being built (notebook progress output).
  print(model_name)
  usr_dir.import_usr_dir(t2t_usr_dir)
  # Keep the override string (`hparams`) and the resulting HParams object
  # under different names instead of shadowing the argument.
  model_hparams = trainer_lib.create_hparams(
      hparams_set, hparams, data_dir=data_dir, problem_name=problem_name)

  translate_model = registry.model(model_name)(
      model_hparams, tf.estimator.ModeKeys.EVAL)

  # Batch size is fixed to 1; ids are fed with shape [1, length, 1, 1].
  inputs = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="inputs")
  targets = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="targets")
  translate_model({
      "inputs": inputs,
      "targets": targets,
  })

  # Must be called after building the training graph, so that the dict will
  # have been filled with the attention tensors. BUT before creating the
  # inference graph otherwise the dict will be filled with tensors from
  # inside a tf.while_loop from decoding and are marked unfetchable.
  atts = get_att_mats(translate_model, model_name)

  # The inference graph reuses the variables created by the call above.
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    samples = translate_model.infer({
        "inputs": inputs,
    }, beam_size=beam_size)["outputs"]

  return inputs, targets, samples, atts


def get_att_mats(translate_model, model_name):
  """Gets the tensors representing the attentions from a built model.

  The attentions are stored in a dict on the Transformer object while building
  the graph, so the model must already have been called once.

  Args:
    translate_model: Transformer object to fetch the attention weights from.
      Reads `hparams.self_attention_type`, `hparams.num_hidden_layers` and the
      `attention_weights` dict.
    model_name: Name of the model; used as the top-level scope prefix of the
      attention-weight keys ("<model_name>/body/...").
  Returns:
    Tuple of attention tensors (graph tensors, evaluate with `sess.run`); (
      enc_atts: Encoder self attention weights.
        A list of `num_layers` tensors of size
        (batch_size, num_heads, inp_len, inp_len)
      dec_atts: Decoder self attention weights.
        A list of `num_layers` tensors of size
        (batch_size, num_heads, out_len, out_len)
      encdec_atts: Encoder-Decoder attention weights.
        A list of `num_layers` tensors of size
        (batch_size, num_heads, out_len, inp_len)
    )
  Raises:
    KeyError: if an expected attention key is missing; the message lists the
      keys that are available (useful when a custom model uses other scopes).
  """
  prefix = "%s/body/" % (model_name,)
  postfix_self_attention = "/multihead_attention/dot_product_attention"
  if translate_model.hparams.self_attention_type == "dot_product_relative":
    postfix_self_attention = ("/multihead_attention/"
                              "dot_product_attention_relative")
  # Encoder-decoder attention never uses the relative variant.
  postfix_encdec = "/multihead_attention/dot_product_attention"

  weights = translate_model.attention_weights

  def _lookup(key):
    # Report the available keys instead of a bare KeyError with one name.
    if key not in weights:
      raise KeyError("Attention weights %r not found; available keys: %s"
                     % (key, sorted(weights)))
    return weights[key]

  layers = range(translate_model.hparams.num_hidden_layers)
  enc_atts = [
      _lookup("%sencoder/layer_%i/self_attention%s"
              % (prefix, i, postfix_self_attention))
      for i in layers]
  dec_atts = [
      _lookup("%sdecoder/layer_%i/self_attention%s"
              % (prefix, i, postfix_self_attention))
      for i in layers]
  encdec_atts = [
      _lookup("%sdecoder/layer_%i/encdec_attention%s"
              % (prefix, i, postfix_encdec))
      for i in layers]

  return enc_atts, dec_atts, encdec_atts



from IPython.display import display
def call_html():
  """Loads RequireJS into the notebook page and registers d3/jquery paths.

  Emits an HTML output whose scripts configure `requirejs` with `base`, `d3`
  and `jquery` paths, so JavaScript rendered afterwards (presumably the view
  produced by `attention.show`) can require them. Assumes the classic Jupyter
  Notebook server, which serves `/static/components/requirejs/require.js` --
  TODO confirm behaviour under JupyterLab.
  """
  import IPython
  requirejs_setup = '''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''
  html_output = IPython.core.display.HTML(requirejs_setup)
  display(html_output)


import os
from tensor2tensor import problems
from tensor2tensor.bin import t2t_decoder  # To register the hparams set
# from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_lib
from tensor2tensor.visualization import attention
# from src.visualization import visualization
# Restrict TensorFlow to GPUs 0 and 1. This takes effect as long as it is set
# before CUDA is initialised (the first session is only created further down).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# HParams
problem_name = 'translate_ende_wmt32k' # registered problem (dataset) name
data_dir = os.path.expanduser('/home/usrname/collaboration/t2t_data/%s'%(problem_name))  # data directory of the problem
model_name = "collaboration"  # registered model name
hparams_set = "collaboration_base" # registered hparams set (model configuration)
hparams = 'max_length=128,num_hidden_layers=6,usedegray=1.0,reuse_n=0'  # custom hparams overrides (adjust to your own model)
t2t_usr_dir = './src/' # directory containing the user-defined model code

# Builds the training + inference graph and loads the problem's encoders.
visualizer = AttentionVisualizer2(hparams_set,hparams, t2t_usr_dir,model_name, data_dir, problem_name, beam_size=1)

# Adds a `global_step` variable to the default graph after the model is built,
# so it is included in the `tf.train.Saver()` created afterwards; presumably so
# the checkpoint's global_step is restored too -- confirm it is needed.
tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')

接着运行以下代码（加载 checkpoint，并对示例句子进行注意力可视化）：

# Restore the trained weights, translate one sentence and render its attention.
# The Saver must be created after every variable (including global_step) exists.
saver = tf.train.Saver()
with tf.Session() as sess:
  ckpt = 'averaged.ckpt-0'  # checkpoint path
  print(ckpt)
  saver.restore(sess, ckpt)

  # Sample sentence to visualize. These statements must stay inside the
  # `with` block: get_vis_data_from_string runs `sess`, which is closed when
  # the block exits.
  # input_sentence = "It is in this spirit that a majority of American governments have passed new laws since 2009 making the registration or voting process more difficult."
  input_sentence = "The Law will never be perfect, but its application should be just - this is what we are missing, in my opinion."
  output_string, inp_text, out_text, att_mats = (
      visualizer.get_vis_data_from_string(sess, input_sentence))
  print(output_string)
  call_html()
  attention.show(inp_text, out_text, *att_mats)

可视化结果:

  

原文地址:https://www.cnblogs.com/huadongw/p/14195355.html