tensorflow 中 inter_op 和 intra_op

[root@localhost custom-resnet-v2]# cat runme.sh
#python demo_slim.py -h
#python demo_slim.py --cpu_num 8 --inter_op_threads 1 --intra_op_threads 8 --dump_timeline True

# export KMP_AFFINITY=verbose,granularity=fine,proclist=[0,1,2,3],explicit
# numactl -C 0-3 python demo_slim.py --cpu_num 4 --inter_op_threads 1 --intra_op_threads 4 >& run1.log &

# OMP_NUM_THREADS caps the OpenMP/MKL worker pool; keep it equal to
# --intra_op_threads so TF's intra-op pool and MKL do not oversubscribe cores.
export OMP_NUM_THREADS=8
python demo_slim.py --cpu_num 8 --inter_op_threads 1 --intra_op_threads 8

[root@localhost custom-resnet-v2]# cat demo_slim.py
# coding: utf8
import os
import sys

import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
import argparse
import time


def make_fake_input(batch_size, input_height, input_width, input_channel):
        """Build a fake NHWC float32 batch whose pixels are all 1.

        Args:
                batch_size: number of images in the batch.
                input_height: image height in pixels.
                input_width: image width in pixels.
                input_channel: number of channels (e.g. 3 for RGB).

        Returns:
                np.ndarray of shape (batch_size, H, W, C), dtype float32,
                filled with 1.0 — a stand-in for a real decoded image batch.
        """
        # Single all-ones uint8 "image" (same result as zeros + fill with 1).
        im = np.ones((input_height, input_width, input_channel), np.uint8)
        images = np.zeros((batch_size, input_height, input_width, input_channel), dtype=np.float32)
        # BUG FIX: xrange() does not exist in Python 3; range() works on both.
        for i in range(batch_size):
                images[i, 0:im.shape[0], 0:im.shape[1], :] = im
        return images


def _str2bool(value):
        """Parse a command-line flag value into a real bool.

        bool("False") is True in Python, so a plain bool() cast cannot be
        used for "--dump_timeline False"; accept common spellings instead.
        """
        if isinstance(value, bool):
                return value
        return value.strip().lower() in ("true", "1", "yes", "y")


def get_parser():
        """
        create a parser to parse argument "--cpu_num --inter_op_threads --intra_op_threads"
        """
        parser = argparse.ArgumentParser(description="Specify tensorflow parallelism")
        # type=int so downstream code gets integers directly; callers that
        # re-cast with int() keep working (backward compatible).
        parser.add_argument("--cpu_num", dest="cpu_num", type=int, default=1, help="specify how many cpus to use.(default: 1)")
        parser.add_argument("--inter_op_threads", dest="inter_op_threads", type=int, default=1, help="specify max inter op parallelism.(default: 1)")
        parser.add_argument("--intra_op_threads", dest="intra_op_threads", type=int, default=1, help="specify max intra op parallelism.(default: 1)")
        # BUG FIX: previously the raw string was kept and later cast with
        # bool(), so "--dump_timeline False" still enabled the timeline dump.
        parser.add_argument("--dump_timeline", dest="dump_timeline", type=_str2bool, default=False, help="specify to dump timeline.(default: False)")
        return parser


def main():
        """Restore the slim resnet_v2 checkpoint and benchmark forward passes.

        Parses the parallelism flags, builds a TF session limited to the
        requested CPU/thread counts, runs `num_steps` inferences on a fake
        all-ones image, and prints the average per-step latency in ms.

        Returns:
                0 on success (used as the process exit code).
        """
        parser = get_parser()
        args = parser.parse_args()
        cpu_num = int(args.cpu_num)
        inter_op_threads = int(args.inter_op_threads)
        intra_op_threads = int(args.intra_op_threads)
        # BUG FIX: bool("False") is True, so casting the raw argument string
        # with bool() enabled the timeline dump for ANY non-empty value.
        # This parse is safe whether args.dump_timeline is a str or a bool.
        dump_timeline = str(args.dump_timeline).strip().lower() in ("true", "1", "yes", "y")
        print("cpu_num: ", cpu_num)
        print("inter_op_threads: ", inter_op_threads)
        print("intra_op_threads: ", intra_op_threads)
        print("dump_timeline: ", dump_timeline)

        # One constant for both the loop bound and the ms/step average
        # (previously the literal 200 was duplicated).
        num_steps = 200

        config = tf.ConfigProto(device_count={"CPU": cpu_num},  # limit CPU device count
                inter_op_parallelism_threads=inter_op_threads,
                intra_op_parallelism_threads=intra_op_threads,
                log_device_placement=False)
        with tf.Session(config=config) as sess:
                imgs = make_fake_input(1, 224, 224, 3)
                saver = tf.train.import_meta_graph("slim_model/slim_model.ckpt.meta")
                saver.restore(sess, tf.train.latest_checkpoint("slim_model/"))

                graph = tf.get_default_graph()
                # NOTE(review): tensor names are checkpoint-specific — confirm
                # they still match if the model is re-exported.
                input_data = graph.get_tensor_by_name("Placeholder:0")
                fc6 = graph.get_tensor_by_name("resnet_v2/avg_fc_fc6_Conv2D/BiasAdd:0")

                time_start = time.time()
                for step in range(num_steps):
                        if dump_timeline:
                                # FULL_TRACE collects per-op timing; each step's
                                # trace overwrites timeline.json, so the file
                                # holds the last step (same as original code).
                                run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                                run_metadata = tf.RunMetadata()
                                result = sess.run(fc6, feed_dict={input_data: imgs}, options=run_options, run_metadata=run_metadata)
                                tm = timeline.Timeline(run_metadata.step_stats)
                                with open('timeline.json', 'w') as f:
                                        f.write(tm.generate_chrome_trace_format())
                        else:
                                result = sess.run(fc6, feed_dict={input_data: imgs})
                        print(result[0][0][0])
                time_end = time.time()
                avg_time = (time_end - time_start) * 1000 / num_steps
                print("AVG Time: ", avg_time, " ms")
        return 0


if __name__ == "__main__":
        sys.exit(main())

原文地址:https://www.cnblogs.com/qccz123456/p/11676026.html