tensorboard

import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

# Clear the default graph so that previously added nodes do not accumulate
# in the graph that gets exported below.
tf.reset_default_graph()
logdir = 'log'

# Synthetic samples of the line y = 8x + 1.5 that the model is fitted to.
x_data = np.random.rand(1000).astype(np.float32)
y_data = x_data * 8 + 1.5

plt.figure()
plt.plot(x_data, y_data)
plt.show()

# Placeholders are fed one scalar sample per training step.
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

# dtype must be passed by keyword: the second positional parameter of
# tf.Variable is `trainable`, so a positional tf.float32 would be taken as
# the trainable flag instead of the dtype.
w = tf.Variable(tf.random_normal([1]), dtype=tf.float32)
b = tf.Variable(tf.random_normal([1]), dtype=tf.float32)

# Linear model and squared-error loss.
y_pre = tf.multiply(x, w) + b

loss = tf.reduce_sum(tf.pow((y - y_pre), 2))
train = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    # One gradient-descent step per sample; iterating the arrays directly
    # avoids repeating the sample count as a magic number.
    for i, (x_i, y_i) in enumerate(zip(x_data, y_data)):
        feed = {x: x_i, y: y_i}
        sess.run(train, feed_dict=feed)
        if i % 50 == 0:
            # Loss on the sample that was just trained on.
            print(sess.run(loss, feed_dict=feed))

    # Export the graph definition to `logdir` as an event file, then close
    # the writer so the file is flushed to disk.
    writer = tf.summary.FileWriter(logdir, tf.get_default_graph())
    writer.close()

1. tf.reset_default_graph() 用于清除 default graph 以及其中不断增加的节点

2. 定义一个 writer,参数为 log 的地址和要写入的图;这里我们直接用 tf.get_default_graph() 来获得 tensorflow 默认生成的图

3. 需要把 writer 关闭(writer.close()),以确保事件文件被写入磁盘

还可以使用

writer = tf.summary.FileWriter('log')

然后

writer.add_graph(sess.graph)

原文地址:https://www.cnblogs.com/francischeng/p/9813037.html