GraphSAGE Code Analysis (Part 2)

1. get_layer_uid()

# global unique layer ID dictionary for layer name assignment
_LAYER_UIDS = {}

def get_layer_uid(layer_name=''):
    """Helper function, assigns unique layer IDs."""
    if layer_name not in _LAYER_UIDS:
        _LAYER_UIDS[layer_name] = 1
        return 1
    else:
        _LAYER_UIDS[layer_name] += 1
        return _LAYER_UIDS[layer_name]

Here _LAYER_UIDS = {} is a dictionary that records each layer name and how many times it has appeared.

In get_layer_uid(), if layer_name has never been seen before, _LAYER_UIDS[layer_name] is set to 1; otherwise the counter is incremented. The updated count is returned as the ID.

Purpose: in class Layer, when no name is given for the variable scope, the number of times the layer class has been instantiated is used to derive a distinct layer ID.
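Calling the helper directly makes the per-name counters visible (a quick demo; the layer names here are just for illustration):

get_layer_uid('dense')           # -> 1: first time 'dense' is seen
get_layer_uid('dense')           # -> 2: the 'dense' counter increments
get_layer_uid('meanaggregator')  # -> 1: each name gets its own counter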

Example: a simplified version of class Layer makes this clear (the real class lower-cases the class name first, which the simplified version omits):

class Layer():
    def __init__(self):
        layer = self.__class__.__name__
        name = layer + '_' + str(get_layer_uid(layer))
        print(name)

layer1 = Layer()
layer2 = Layer()

# Output:
# Layer_1
# Layer_2

2. class Layer

class Layer defines the basic API that all layer objects share.

import tensorflow as tf

class Layer(object):
    """Base layer class. Defines basic API for all layer objects.
    Implementation inspired by keras (http://keras.io).

    # Properties
        name: String, defines the variable scope of the layer.
        logging: Boolean, switches TensorFlow histogram logging on/off

    # Methods
        _call(inputs): Defines computation graph of layer
            (i.e. takes input, returns output)
        __call__(inputs): Wrapper for _call()
        _log_vars(): Log all variables
    """

    def __init__(self, **kwargs):
        allowed_kwargs = {'name', 'logging', 'model_size'}
        for kwarg in kwargs.keys():
            assert kwarg in allowed_kwargs, 'Invalid keyword argument: ' + kwarg
        name = kwargs.get('name')
        if not name:
            layer = self.__class__.__name__.lower()  # e.g. "layer"
            name = layer + '_' + str(get_layer_uid(layer))
        self.name = name
        self.vars = {}
        logging = kwargs.get('logging', False)
        self.logging = logging
        self.sparse_inputs = False

    def _call(self, inputs):
        return inputs

    def __call__(self, inputs):
        with tf.name_scope(self.name):
            if self.logging and not self.sparse_inputs:
                tf.summary.histogram(self.name + '/inputs', inputs)
            outputs = self._call(inputs)
            if self.logging:
                tf.summary.histogram(self.name + '/outputs', outputs)
            return outputs

    def _log_vars(self):
        for var in self.vars:
            tf.summary.histogram(self.name + '/vars/' + var, self.vars[var])

Methods:

__init__(): validates the keyword arguments name, logging, and model_size, then initializes the instance variables name, vars (an empty dict), logging, and sparse_inputs.

_call(inputs): defines the layer's computation graph: takes inputs, returns outputs.

__call__(inputs): a wrapper around _call(); on top of _call()'s basic behavior it adds functionality, mainly tf.summary.histogram() calls so that the distributions of the inputs and outputs can be inspected as histograms (see the subclass sketch after these notes).

_log_vars(): logs all variables; each entry in vars is written out as a histogram summary.
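Putting the API together: a minimal subclass sketch (hypothetical, not part of the GraphSAGE code; it assumes the TensorFlow 1.x graph API used throughout this post). A concrete layer only needs to override _call(); __call__() supplies the name scope and the optional histogram summaries:

import tensorflow as tf

class Scale(Layer):
    """Toy layer that multiplies its input by a constant factor."""

    def __init__(self, factor, **kwargs):
        super(Scale, self).__init__(**kwargs)
        self.factor = factor

    def _call(self, inputs):
        return inputs * self.factor

layer = Scale(2.0, logging=True)        # no name given -> "scale_1" via get_layer_uid()
outputs = layer(tf.constant([1., 2.]))  # __call__ adds the summaries around _call()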

3. class Dense

The Dense layer implements the basic functionality of a fully connected layer, i.e. it ultimately computes act(xW + b), where act defaults to tf.nn.relu, x has shape [batch, input_dim], W has shape [input_dim, output_dim], and b has shape [output_dim].

__init__(): reads the arguments and initializes the member variables. The purpose of num_features_nonzero and featureless is not yet clear at this point.

_call(): applies dropout to the inputs (tf.nn.dropout() in TF 1.x takes a keep probability, hence the 1 - self.dropout), then computes and returns act(xW + b).

class Dense(Layer):
    """Dense layer."""

    def __init__(self, input_dim, output_dim, dropout=0.,
                 act=tf.nn.relu, placeholders=None, bias=True, featureless=False,
                 sparse_inputs=False, **kwargs):
        super(Dense, self).__init__(**kwargs)

        self.dropout = dropout

        self.act = act
        self.featureless = featureless
        self.bias = bias
        self.input_dim = input_dim
        self.output_dim = output_dim

        # helper variable for sparse dropout
        self.sparse_inputs = sparse_inputs
        if sparse_inputs:
            self.num_features_nonzero = placeholders['num_features_nonzero']

        with tf.variable_scope(self.name + '_vars'):
            self.vars['weights'] = tf.get_variable(
                'weights', shape=(input_dim, output_dim), dtype=tf.float32,
                initializer=tf.contrib.layers.xavier_initializer(),
                regularizer=tf.contrib.layers.l2_regularizer(FLAGS.weight_decay))
            if self.bias:
                self.vars['bias'] = zeros([output_dim], name='bias')

        if self.logging:
            self._log_vars()

    def _call(self, inputs):
        x = inputs
        # tf.nn.dropout() takes the keep probability, hence 1 - dropout rate
        x = tf.nn.dropout(x, 1 - self.dropout)

        # transform
        output = tf.matmul(x, self.vars['weights'])

        # bias
        if self.bias:
            output += self.vars['bias']

        return self.act(output)
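A hypothetical usage sketch (the dimensions are invented; FLAGS.weight_decay and the zeros() initializer are helpers defined elsewhere in the GraphSAGE codebase and assumed to be in scope):

x = tf.placeholder(tf.float32, shape=(None, 128))  # a batch of 128-d features
dense = Dense(input_dim=128, output_dim=64, dropout=0.5, logging=True)
h = dense(x)  # h = relu(x W + b), shape (None, 64); dropout is applied to x first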

Original article: https://www.cnblogs.com/shiyublog/p/9894617.html