tf.while_loop

while 循环

def while_loop(cond,          ### 一个函数,负责判断循环是否进行
               body,          ### 一个函数,循环体,更新变量
               loop_vars,     ### 初始循环变量,可以是多个,这些变量是 cond、body 的输入 和输出
               shape_invariants=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               name=None,
               maximum_iterations=None,
               return_same_structure=False):

返回 迭代后的 loop_vars

def cond(i, n):
    return i < n

def body(i, n):
    i = i + 1
    return i, n

i = tf.get_variable("ii", dtype=tf.int32, shape=[], initializer=tf.ones_initializer())
# i = 1                 # 也可以
# i = tf.constant(1)    # 也可以
n = tf.constant(10)
i, n = tf.while_loop(cond, body, [i, n])
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    res = sess.run([i, n])
    print(res)      # [10, 10]

注意:cond 和 body 的输入和输出要相同,且等于 loop_vars,即使在函数中没有用到全部的 loop_vars,也要做为输入和输出

参考资料:

原文地址:https://www.cnblogs.com/yanshw/p/12376823.html