tf.identity()函数解析(最清晰的解释)

欢迎关注WX公众号:【程序员管小亮】

这两天看batch normalization的代码时,学到滑动平均窗口函数ExponentialMovingAverage时,碰到一个函数tf.identity()函数,特此记录。

tf.identity()函数用于返回一个和input一样的新的tensor。

tf.identity(
	input,
	name=None
)
#Return a tensor with the same shape and contents as input.
#返回一个tensor,contents和shape都和input的一样

简单来说,就是返回一个和input一样的新的tensor。

例子1:

import tensorflow as tf
w = tf.Variable(1.0)
ema = tf.train.ExponentialMovingAverage(0.9)
update = tf.assign_add(w, 1.0)

ema_op = ema.apply([update])
with tf.control_dependencies([ema_op]):
    ema_val = ema.average(update)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(3):
        print(sess.run([ema_val]))
> [0.0]
> [0.0]
> [0.0]

理想的情况下,在我们 sess.run([ema_val]), ema_op 都会被先执行,然后再计算ema_val,实际情况并不是这样,为什么?

有兴趣的可以看一下源码,就会发现 ema.average(update) 不是一个 op,它只是从ema对象的一个字典中取出键对应的 tensor而已,然后赋值给ema_val。这个 tensor是由一个在 tf.control_dependencies([ema_op]) 外部的一个 op 计算得来的,所以control_dependencies会失效。解决方法也很简单,看代码:

import tensorflow as tf
w = tf.Variable(1.0)
ema = tf.train.ExponentialMovingAverage(0.9)
update = tf.assign_add(w, 1.0)

ema_op = ema.apply([update])
with tf.control_dependencies([ema_op]):
    ema_val = tf.identity(ema.average(update)) #一个identity搞定

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for i in range(3):
        print(sess.run([ema_val]))
> [0.20000005]
> [0.4800001]
> [0.8320002]

例子2:

import tensorflow as tf

x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
with tf.control_dependencies([x_plus_1]):
    y = x
init = tf.global_variables_initializer()
with tf.Session() as session:
    init.run() # 相当于session.run(init)
    for i in range(5):
        print(y.eval()) # y.eval()这个相当于session.run(y)
> 0.0
  0.0
  0.0
  0.0
  0.0

理想的情况下,输出应该是:[1.0, 2.0, 3.0, 4.0, 5.0],实际情况并不是这样,为什么?

1 tf.control_dependencies()是一个在Graph上的operation,所以要想使得其参数起作用,就需要for循环里面利用sess.run()来执行;

2 y = x只是一个简单的赋值操作,而with tf.control_dependencies()作用域(也就是冒号下的代码行)只对op起作用,所以需要将tensor利用tf.identity()来转化为op。

针对以上原因,给出两个相应的解决方法:

1.
import tensorflow as tf

x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
with tf.control_dependencies([x_plus_1]):
    y = x
init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run() # 相当于session.run(init)
    for i in range(5):
        sess.run(x_plus_1)
        print(y.eval()) # y.eval()这个相当于session.run(y)
> 1.0
   2.0
   3.0
   4.0
   5.0
2.
import tensorflow as tf

x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
with tf.control_dependencies([x_plus_1]):
    y = tf.identity(x)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run() # 相当于session.run(init)
    for i in range(5):
        print(y.eval()) # y.eval()这个相当于session.run(y)
> 1.0
   2.0
   3.0
   4.0
   5.0

Graph上不论是tensor还是operation的更新都要借助op来进行,而将一个tensor转化为op最简单的方法就是tf.identity()。

python课程推荐。
在这里插入图片描述

参考文章:

tensorflow学习笔记(四十一):control dependencies
tf.control_dependencies()和tf.identity()

原文地址:https://www.cnblogs.com/hzcya1995/p/13302847.html