TensorFlow 的 Hello World（MNIST 手写数字识别）

  1 import tensorflow as tf;
  2 from tensorflow.examples.tutorials.mnist import input_data
  3 
  4 ##定义网络结构
  5 input_nodes  = 784
  6 output_nodes = 10
  7 layer1_nodes = 500
  8 #定义超参数
  9 #自动设置学习率
 10 learning_rate_base=  0.8;
 11 learning_decay = 0.99   ;
 12 decay_step=100          ;
 13 
 14 #滑动平均
 15 moving_average__decay = 0.99
 16 regularizer_rate  = 0.0001;
 17 train_step=30000
 18 batch_size= 100
 19 
 20 
 21 def inference(tensor1,weight1,bias1,weight2,bias2,average_class=None):
 22     if(average_class==None):
 23         layer1=tf.nn.relu(   tf.matmul(tensor1,weight1)+ bias1 )
 24         return tf.matmul( layer1,weight2 ) + bias2
 25     else:
 26         layer1 = tf.nn.relu(tf.matmul(tensor1, average_class.average(weight1)) + average_class.average(bias1))
 27         return tf.matmul(layer1, average_class.average(weight2) ) + average_class.average(bias2)
 28 
 29 def get_weight(shape):
 30     weight=tf.Variable(tf.truncated_normal(shape=shape,stddev=0.1),tf.float32)
 31     tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer_rate)(weight))
 32     return weight
 33 
 34 def get_bias(shape):
 35     return tf.Variable(tf.zeros(shape))
 36 
 37 def train(mnist):
 38     #定义输入输出
 39     train_x=tf.placeholder(tf.float32,shape=[None,input_nodes],name='train_x')
 40     train_y=tf.placeholder(tf.float32,shape=[None,output_nodes],name='train_y' )
 41 
 42     weight1=get_weight( [input_nodes,layer1_nodes] )
 43     bias1   =get_bias([layer1_nodes])
 44 
 45     weight2=get_weight([layer1_nodes,output_nodes]);
 46     bias2  =get_bias([output_nodes])
 47     results = inference(train_x, weight1, bias1, weight2, bias2, None)
 48 
 49     #定义学习率
 50     global_step = tf.Variable(0, trainable=False)
 51     learning_rate = tf.train.exponential_decay(learning_rate_base, global_step,  mnist.train.num_examples / batch_size, learning_decay,staircase=True)
 52 
 53     #定义损失、优化器
 54 
 55     ce= tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( logits=results,labels=tf.argmax( train_y,1) ) )
 56     loss=ce+tf.add_n( tf.get_collection('losses') )
 57     tf.summary.scalar('lost',loss)
 58 
 59     optimizer= tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step);
 60 
 61     #定义滑动平均
 62     ema = tf.train.ExponentialMovingAverage(moving_average__decay, global_step);
 63     maintain_average_op = ema.apply( tf.trainable_variables())
 64     with tf.control_dependencies([optimizer,maintain_average_op]):
 65         train_op=tf.no_op(name='train')
 66 
 67     #预测准确率
 68     average_y=inference(train_x,weight1,bias1,weight2,bias2,ema);
 69     correction_prediction = tf.equal(  tf.argmax( average_y,1 ) ,tf.argmax(train_y,1))
 70     accuracy = tf.reduce_mean(tf.cast(correction_prediction,tf.float32));
 71 
 72     with tf.Session() as sess:
 73         tf.global_variables_initializer().run()
 74 
 75         validate_feed={train_x:mnist.validation.images,train_y:mnist.validation.labels}
 76         test_feed    ={train_x:mnist.test.images,train_y:mnist.test.labels}
 77 
 78         #汇总
 79         merged_summary_op = tf.summary.merge_all()
 80         summaryWriter = tf.summary.FileWriter('./log/mnist_with_summaries',sess.graph)
 81 
 82         #迭代训练
 83         for i in range(train_step):
 84             if(i%1000 == 0 ):
 85                 validate_acc=sess.run(accuracy,feed_dict=validate_feed);
 86                 print('After %d training steps,using aaverage model is %g '%(i,validate_acc))
 87 
 88             xt,yt=mnist.train.next_batch(batch_size);
 89             sess.run( train_op,feed_dict={ train_x :xt,train_y:yt}          );
 90             summary_str=sess.run( merged_summary_op,feed_dict={ train_x :xt,train_y:yt} );
 91             summaryWriter.add_summary(summary_str,i)
 92 
 93 
 94         test_acc=sess.run(accuracy,feed_dict=test_feed)
 95         print('accuracy is %g'%(test_acc));
 96 def main():
 97     mnist= input_data.read_data_sets('./MNIST_data',one_hot=True)
 98     train(mnist);
 99 
if __name__ == '__main__':
    # Train only when executed as a script, not when imported.
    main()
原文地址：https://www.cnblogs.com/z-bear/p/10455547.html