tensorflow基础【4】-计算图 graph

tensorflow,tensor就是数据,flow就是流,tensorflow就是数据流

tensorflow 是一个用计算图的形式来表示计算的编程系统,所有的数据和计算都会被转化成计算图上的一个节点,节点之间的边就是数据流(数据流动的轨迹)。

计算图的使用

1. 建立节点

2. 执行计算

计算图有两种形式

默认的计算图

tf 维护一个默认的计算图,

get_default_graph:获取默认计算图

graph:获取节点所属计算图

import tensorflow as tf

a = tf.constant([1., 2.], name = 'a')
b = tf.constant([2., 3.], name = 'b')
result = a + b

print(a.graph is tf.get_default_graph())            # True

数据本身就是节点,该节点的 graph 就是默认计算图

自定义计算图

tf.Graph 可以生成新的计算图,不同计算图之间的数据和计算不能共享

## g1
g1 = tf.Graph()
with g1.as_default():
    # 在计算图 g1 中定义变量 “v” ,并设置初始值为 0。
    v = tf.get_variable("v", [1], initializer = tf.zeros_initializer()) # 设置初始值为0,shape 为 1

## g2
g2 = tf.Graph()
with g2.as_default():
    # 在计算图 g2 中定义变量 “v” ,并设置初始值为 1。
    v = tf.get_variable("v", [1], initializer = tf.ones_initializer()) # 设置初始值为1


# 在计算图 g1 中读取变量“v” 的取值
with tf.Session(graph = g1) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope("", reuse=True):
        print(sess.run(tf.get_variable("v")))               # [0.]

# 在计算图 g2 中读取变量“v” 的取值
with tf.Session(graph = g2) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope("", reuse=True):
        print(sess.run(tf.get_variable("v")))               # [1.]

## g3
g3 = tf.Graph()
with g3.as_default():
    # 在计算图 g2 中定义变量 “v” ,并设置初始值为 1。
    v = tf.get_variable("v2", [1], initializer = tf.ones_initializer()) # 设置初始值为1


# 在计算图 g1 中读取变量“v” 的取值
with tf.Session(graph = g3) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope("", reuse=True):
        print(sess.run(tf.get_variable("v2")))               # [1.]
        print(sess.run(tf.get_variable("v")))                # 报错 Variable v does not exist

可以看到 g3 无法调用 g2 中的变量v

计算图可以用来隔离张量和计算

计算图的操作

保存

g1 = tf.Graph()
with g1.as_default():
    # 需要加上名称,在读取pb文件的时候,是通过name和下标来取得对应的tensor的
    c1 = tf.constant(4.0, name='c1')

with tf.Session(graph=g1) as sess1:
    print(sess1.run(c1))                        # 4.0


# g1的图定义,包含pb的path, pb文件名,是否是文本默认False
tf.train.write_graph(g1.as_graph_def(),'.','graph.pb',False)

读取

import tensorflow as tf#load graph
with tf.gfile.FastGFile("./graph.pb",'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

sess = tf.Session()
c1_tensor = sess.graph.get_tensor_by_name("c1:0")
c1 = sess.run(c1_tensor)
print(c1)                       # 4.0

穿插调用

g1 = tf.Graph()
with g1.as_default():
    # 声明的变量有名称是一个好的习惯,方便以后使用
    c1 = tf.constant(4.0, name="c1")

g2 = tf.Graph()
with g2.as_default():
    c2 = tf.constant(20.0, name="c2")

with tf.Session(graph=g2) as sess1:
    # 通过名称和下标来得到相应的值
    c1_list = tf.import_graph_def(g1.as_graph_def(), return_elements = ["c1:0"], name = '')
    print(sess1.run(c1_list[0]+c2))             # 24.0

指定计算图的运行设备

g = tf.Graph()
# 指定计算运行的设备
with g.device('/gpu:0'):
    result = a + b

计算图资源管理

在一个计算图中,可以通过集合来管理不同的资源。

比如通过 tf.add_to_collection 将资源加入一个或多个集合中,然后通过 tf.get_collection 获取一个集合里的所有资源

参考资料:

https://www.cnblogs.com/q735613050/p/7632792.html

原文地址:https://www.cnblogs.com/yanshw/p/10622761.html