TensorFlow: a program killed during a model save corrupts the parameter data and makes the model fail to load

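When training and saving a model with TensorFlow, a sudden event such as a power failure or a system freeze can kill the program while it is writing a checkpoint. The checkpoint directory is then left with files named like xxx.tempstate.
When you later load the model to restore the session, it fails with the error: checksum failed. TensorFlow stores a checksum alongside the saved tensor data (a CRC32C value, not an MD5 hash), and the checksum computed from the damaged checkpoint files (three in total, typically the .meta, .index and .data-* files) no longer matches it. A checkpoint damaged this way cannot be recovered.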
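To make the failure mode concrete, here is a minimal standard-library sketch of why an interrupted write fails a checksum test. It is not TensorFlow's code: `write_with_checksum`, `verify`, `simulate_interrupted_write` and the file names are hypothetical, and `zlib.crc32` only stands in for the CRC32C that TensorFlow uses internally.

```python
import os
import tempfile
import zlib


def write_with_checksum(path, payload):
    """Write payload to path and store its CRC32 in a sidecar file."""
    with open(path, "wb") as f:
        f.write(payload)
    with open(path + ".crc", "w") as f:
        f.write(str(zlib.crc32(payload)))


def verify(path):
    """Return True if the bytes on disk still match the stored checksum."""
    with open(path + ".crc") as f:
        stored = int(f.read())
    with open(path, "rb") as f:
        actual = zlib.crc32(f.read())
    return stored == actual


def simulate_interrupted_write(path, keep_bytes):
    """Truncate the file, as if the process died part-way through writing."""
    with open(path, "r+b") as f:
        f.truncate(keep_bytes)


workdir = tempfile.mkdtemp()
ckpt = os.path.join(workdir, "model.data")  # hypothetical file name

write_with_checksum(ckpt, b"weights" * 1000)
intact_ok = verify(ckpt)       # checked before any damage

simulate_interrupted_write(ckpt, 100)
damaged_ok = verify(ckpt)      # checked after the file was cut short

print(intact_ok, damaged_ok)
```

The second check fails because the stored checksum describes data that is no longer fully on disk, and nothing in the truncated file lets you rebuild the missing bytes. That is why the only practical defence is to keep older, intact checkpoints.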
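To avoid this situation, combine two settings, global_step and max_to_keep, so that several checkpoints are kept as backups. Keeping only a single saved model is high-risk.
global_step must be a monotonically increasing variable. It is a global tensor in the graph that TensorFlow builds, and it has to be incremented with tf.assign_add(global_step, 1) each time sess.run(opt) runs (passing global_step=global_step to the optimizer's minimize() performs this increment automatically). Initializing it with initializer=0 is enough.
max_to_keep is set when creating tf.train.Saver. Its purpose is to stop checkpoint files from piling up: the checkpoint directory never holds more than the max_to_keep most recent checkpoints.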
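The retention policy described above can be sketched without TensorFlow. This is a hedged, standard-library illustration and not the tf.train.Saver implementation: `save_checkpoint`, `list_steps`, the `model-<step>.ckpt` naming and the write-to-temp-then-rename step are all assumptions made for the example. With TensorFlow itself, the corresponding calls are `tf.train.Saver(max_to_keep=N)` and `saver.save(sess, path, global_step=global_step)`.

```python
import os
import tempfile


def list_steps(directory):
    """Return the step numbers of all checkpoints in the directory."""
    steps = []
    for name in os.listdir(directory):
        if name.startswith("model-") and name.endswith(".ckpt"):
            steps.append(int(name[len("model-"):-len(".ckpt")]))
    return steps


def save_checkpoint(directory, step, payload, max_to_keep=3):
    """Write one step-numbered checkpoint, then prune the oldest ones."""
    final_path = os.path.join(directory, "model-%d.ckpt" % step)

    # Write under a temporary name first, so a crash never leaves a
    # half-written file under the final checkpoint name.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "wb") as f:
        f.write(payload)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, final_path)  # atomic rename on the same filesystem

    # Keep only the max_to_keep checkpoints with the highest step numbers.
    steps = sorted(list_steps(directory))
    for old_step in steps[:-max_to_keep]:
        os.remove(os.path.join(directory, "model-%d.ckpt" % old_step))
    return final_path


ckpt_dir = tempfile.mkdtemp()
global_step = 0
for _ in range(5):
    global_step += 1  # the counter must increase on every training step
    save_checkpoint(ckpt_dir, global_step, b"weights at step %d" % global_step)

remaining = sorted(list_steps(ckpt_dir))
print(remaining)  # [3, 4, 5]
```

Because each save gets its own step-numbered name, a crash during the save for step 5 can damage at most that one checkpoint; steps 3 and 4 remain loadable.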
Example code:

# iterative_inference.py
# NN inference in an iterative manner, instead of a forward single shot.

import numpy as np
import os
import platform
import matplotlib.pyplot as plt

import dataset
import components.utils as utils

import tensorflow as tf

def get_conv_weights(w, h, chn_in, chn_out):
    dim = [w, h, chn_in, chn_out]
    # stddev must be passed by keyword: the second positional argument is the mean
    init_op = tf.truncated_normal(dim, stddev=0.02)
    return tf.get_variable(
        name='weights',
        initializer=init_op)

def get_fc_weights(chn_in, chn_out):
    dim = [chn_in, chn_out]
    # stddev must be passed by keyword: the second positional argument is the mean
    init_op = tf.truncated_normal(dim, stddev=0.02)
    return tf.get_variable(
        name='weights',
        initializer=init_op)

def get_bias(filters):
    init_op = tf.zeros([filters], dtype=tf.float32)
    return tf.get_variable(
        name='bias',
        initializer=init_op)

def get_nonlinear_layer(inputs):
    return tf.nn.leaky_relu(inputs, alpha=0.2)

def get_conv_layer(inputs, kernel_size, strides, filters):
    w = kernel_size[0]
    h = kernel_size[1]
    chn_in = inputs.shape.as_list()[-1]
    chn_out = filters
    weights = get_conv_weights(w, h, chn_in, chn_out)
    bias = get_bias(chn_out)
    layer = tf.nn.conv2d(inputs, weights, strides, padding='SAME')
    layer = tf.nn.bias_add(layer, bias)
    return layer

def get_fc_layer(inputs, units):
    chn_in = inputs.shape.as_list()[-1]
    chn_out = units
    weights = get_fc_weights(chn_in, chn_out)
    bias = get_bias(chn_out)
    layer = tf.matmul(inputs, weights)
    layer = tf.nn.bias_add(layer, bias)
    return layer

def get_controlled_layer(inputs, control):  # define your own control strategy
    return tf.nn.bias_add(inputs, control)

def get_loss(outputs, feedbacks):
    # keyword arguments keep labels and logits unambiguous across TF 1.x versions
    return tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=feedbacks, logits=outputs)

def convert_tensor_conv2fc(tensor):  # issue: use max or mean for pooling?
    return tf.reduce_mean(tensor, axis=[1, 2])

class IINN(object):
    def __init__(self, dim_x, dim_y,
                 conv_config, fc_config, att_config):
        self.inputs = tf.placeholder(shape=dim_x, dtype=tf.float32)
        self.feedbacks = tf.placeholder(shape=dim_y, dtype=tf.float32)

        self.rec_layers = []
        self.rec_layers.append(self.inputs)

        self.att_layers = []
        self.att_layers.append(self.feedbacks)

        self.ctl_layers = []

        # the optimizer
        # Learning rate stages: 1E-3, 1E-4, 1E-5.
        # On CIFAR-10, it converged on 0.4 (cross entropy)
        self.optimzer = tf.train.AdamOptimizer(learning_rate=1E-4)

        scope = 'attention'
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            # attention module
            sub_scope = 'fc_%d'
            for i in range(len(att_config)):
                with tf.variable_scope(sub_scope % i, reuse=tf.AUTO_REUSE):
                    fc_ = get_fc_layer(
                        self.att_layers[-1],
                        att_config[i]['units'])
                    fc_ = get_nonlinear_layer(fc_)
                    self.att_layers.append(fc_)
            # bridge tensor from attention to the biases of conv
            num_biases = 0
            for i in range(len(conv_config)):
                num_biases += conv_config[i]['filters']
            with tf.variable_scope(sub_scope % len(att_config)):
                fc_ = get_fc_layer(
                    self.att_layers[-1],
                    num_biases)
                assert fc_.shape.as_list()[0] == 1
                self.att_layers.append(fc_[0])

        scope = 'control'
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            # creating sub operations with back-prop killed
            conv_bias_ctl = []
            offset = 0
            for i in range(len(conv_config)):
                ctl_grad_free = \
                    self.att_layers[-1][offset:offset + conv_config[i]['filters']]
                self.ctl_layers.append(ctl_grad_free)
                assert conv_config[i]['filters'] == ctl_grad_free.shape.as_list()[0]
                offset += ctl_grad_free.shape.as_list()[0]

        scope = 'recognition'
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            sub_scope = 'conv_%d'
            for i in range(len(conv_config)):
                with tf.variable_scope(sub_scope % i):
                    conv_ = get_conv_layer(
                        self.rec_layers[-1],
                        conv_config[i]['ksize'],
                        conv_config[i]['strides'],
                        conv_config[i]['filters'])
                    conv_ = get_controlled_layer(conv_, self.ctl_layers[i])
                    conv_ = get_nonlinear_layer(conv_)
                    self.rec_layers.append(conv_)
            # bridge tensor between conv and fc to let it flow thru
            layer = convert_tensor_conv2fc(self.rec_layers[-1])
            self.rec_layers.append(layer)

            # creating classifier using fc layers
            sub_scope = 'fc_%d'
            for i in range(len(fc_config)):
                with tf.variable_scope(sub_scope % i):
                    fc_ = get_fc_layer(
                        self.rec_layers[-1],
                        fc_config[i]['units'])
                    fc_ = get_nonlinear_layer(fc_)
                    self.rec_layers.append(fc_)
            # the last classifier layer -- using fc without nonlinearization
            with tf.variable_scope(sub_scope % len(fc_config)):
                self.outputs = get_fc_layer(self.rec_layers[-1], dim_y[1])
            self.rec_layers.append(self.outputs)

            # calculate the loss
            self.rec_loss = get_loss(self.outputs, self.feedbacks)

        # Creating minimizers for different training purposes
        # group the variables by their namespace
        vars = tf.global_variables()
        rec_vars = []
        att_vars = []
        for i in range(len(vars)):
            if vars[i].name.find('recognition') != -1:
                rec_vars.append(vars[i])
            elif vars[i].name.find('attention') != -1:
                att_vars.append(vars[i])
            else:
                raise NameError('unknown variables: %s' % vars[i].name)

        self.minimizer_rec = self.optimzer.minimize(
            self.rec_loss, var_list=rec_vars, name='opt_rec')
        self.minimizer_att = self.optimzer.minimize(
            self.rec_loss, var_list=att_vars, name='opt_att')

        # network self check
        print("================================ VARIABLES ===================================")
        vars = tf.global_variables()
        for i in range(len(vars)):
            print("var#%03d:%40s %16s %12s" %
                  (i, vars[i].name[:-2], vars[i].shape, str(vars[i].dtype)[9:-6]))
        print("==============================================================================")
        print("\n")
        print("================================ OPERATORS ===================================")
        ops = self.rec_layers
        for i in range(len(ops)):
            print("opr#%03d:%40s %16s %12s" %
                  (i, ops[i].name[:-2], ops[i].shape, str(ops[i].dtype)[9:-2]))
        ops = self.att_layers
        for i in range(len(ops)):
            print("opr#%03d:%40s %16s %12s" %
                  (i, ops[i].name[:-2], ops[i].shape, str(ops[i].dtype)[9:-2]))
        ops = self.ctl_layers
        for i in range(len(ops)):
        </span><span class="sc5">print</span><span class="sc10">(</span><span class="sc3">"opr#%03d:%40s %16s %12s"</span><span class="sc0"> </span><span class="sc10">%</span><span class="sc0">
              </span><span class="sc10">(</span><span class="sc11">i</span><span class="sc10">,</span><span class="sc0"> </span><span class="sc11">ops</span><span class="sc10">[</span><span class="sc11">i</span><span class="sc10">].</span><span class="sc11">name</span><span class="sc10">[:-</span><span class="sc2">2</span><span class="sc10">],</span><span class="sc0"> </span><span class="sc11">ops</span><span class="sc10">[</span><span class="sc11">i</span><span class="sc10">].</span><span class="sc11">shape</span><span class="sc10">,</span><span class="sc0"> </span><span class="sc11">str</span><span class="sc10">(</span><span class="sc11">ops</span><span class="sc10">[</span><span class="sc11">i</span><span class="sc10">].</span><span class="sc11">dtype</span><span class="sc10">)[</span><span class="sc2">9</span><span class="sc10">:-</span><span class="sc2">2</span><span class="sc10">]))</span><span class="sc0">
    </span><span class="sc5">print</span><span class="sc10">(</span><span class="sc3">"=============================================================================="</span><span class="sc10">)</span><span class="sc0">

    def attention(self, x, y):
        pass

    def inference(self, x, a):
        pass

    def getInputPlaceHolder(self):
        return self.inputs

    def getFeedbackPlaceHolder(self):
        return self.feedbacks

    def getOutputTensor(self):
        return self.outputs

    def getControlTensors(self):
        return self.ctl_layers

    def getLoss(self):
        return self.rec_loss

    def getOptRec(self):
        return self.minimizer_rec

    def getOptAtt(self):
        return self.minimizer_att

def new_conv_config(k_w, k_h, s_w, s_h, filters):
    # describe one convolution layer: kernel size, strides and number of filters
    demo_config = dict()
    demo_config['ksize'] = (k_w, k_h)
    demo_config['strides'] = (1, s_w, s_h, 1)
    demo_config['filters'] = filters
    return demo_config

def new_fc_config(units):
    # describe one fully connected layer by its number of units
    demo_config = dict()
    demo_config['units'] = units
    return demo_config

def Build_IINN(n_class):
    dim_x = [1, None, None, 3]
    dim_y = [1, n_class]

    # configure the convolution layers: 3x3 kernels, stride 2, with 8, 16, 32, 64 filters
    n_conv = 4
    conv_config = [None] * n_conv
    for i in range(n_conv):
        conv_config[i] = new_conv_config(3, 3, 2, 2, 8 << i)

    # configure the fully connected layers: 16, 32, 64 units
    n_fc = 3
    fc_config = [None] * n_fc
    for i in range(n_fc):
        fc_config[i] = new_fc_config(16 << i)

    # configure the special module, feedback attention: 64, 32, 16 units
    n_att = 3
    att_config = [None] * n_att
    for i in range(n_att):
        att_config[i] = new_fc_config(64 >> i)

    return IINN(dim_x, dim_y,
                conv_config,
                fc_config,
                att_config)

def Train_IINN(iinn_: IINN, data: dict, model_path: str) -> float:
    xx = data['input']
    yy = data['output']

    x_t = iinn_.getInputPlaceHolder()     # tensor of inputs
    y_t = iinn_.getOutputTensor()         # tensor of outputs
    c_t = iinn_.getControlTensors()       # tensor of all control signals
    f_t = iinn_.getFeedbackPlaceHolder()  # tensor of feedback

    loss_t = iinn_.getLoss()
    opt_rec = iinn_.getOptRec()
    opt_att = iinn_.getOptAtt()

    # stage 1: train without attention (a plain convolution classifier)
    # set all the control signals to 0
    ctl_sig = []
    for i in range(len(c_t)):
        ctl_sig.append(np.array([0] * c_t[i].shape.as_list()[0]))

    # the batch size is always 1 because of a limit of the control module;
    # BAT_NUM is the number of single-sample steps averaged for each loss report
    BAT_NUM = 1024
    MAX_ITR = 100000 * BAT_NUM
    CVG_EPS = 1e-2
    itr = 0
    eps = 1E10

    # set up the global step counter
    global_step = tf.get_variable(name="global_step", initializer=0)
    step_next = tf.assign_add(global_step, 1, use_locking=True)

    # establish the training context
    sess = tf.Session()
    var_list = tf.trainable_variables()
    # keep the 5 most recent checkpoints so that one damaged save is not fatal
    saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
    # load the pretrained model if it exists
    if tf.train.checkpoint_exists(model_path):
        saver.restore(sess, model_path)
        utils.initialize_uninitialized(sess)
    else:
        sess.run(tf.global_variables_initializer())
    # training loop
    loss = np.zeros([BAT_NUM], dtype=np.float32)
    while itr < MAX_ITR and eps > CVG_EPS:
        idx = np.random.randint(xx.shape[0])
        feed_in = dict()
        feed_in[x_t] = xx[idx:idx+1, :, :, :]
        feed_in[f_t] = yy[idx:idx+1, :]
        for i in range(len(c_t)):
            feed_in[c_t[i]] = ctl_sig[i]
        # one training step; step_next increments global_step on every run
        loss[itr % BAT_NUM], _, _ = sess.run(
            [loss_t, opt_rec, step_next], feed_dict=feed_in)
        itr += 1
        if itr % BAT_NUM == 0:
            eps = np.mean(loss)
            print("batch#%05d loss=%3.5f" % (itr // BAT_NUM, eps))
        if itr % (BAT_NUM * 16) == 0:
            # the files are written as "<model_path>-<global_step>.*"
            saver.save(sess, model_path, global_step=global_step)
    return eps
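
# --- Added sketch (an assumption, not part of the original listing) ---
# saver.save(..., global_step=...) in Train_IINN writes each checkpoint as
# "<model_path>-<step>.index" plus its data/meta files, so after a crash the
# newest checkpoint may be damaged while older ones are still readable.
# The hypothetical helper below lists the step-suffixed prefixes found on
# disk, newest first, so that a caller could try saver.restore() on each in
# turn until one loads. It only inspects file names and assumes the default
# checkpoint format that produces ".index" files.
import glob
import re


def list_checkpoint_prefixes(model_path):
    """Return the prefixes "<model_path>-<step>" found on disk, highest step first."""
    found = []
    for index_file in glob.glob(glob.escape(model_path) + "-*.index"):
        match = re.search(r"-(\d+)\.index$", index_file)
        if match:
            # keep (step, prefix) so that sorting orders by the numeric step
            found.append((int(match.group(1)), index_file[:-len(".index")]))
    found.sort(reverse=True)
    return [prefix for _, prefix in found]


# Possible use (untested against a real damaged checkpoint; it assumes that a
# failed checksum surfaces as tf.errors.DataLossError):
#     for prefix in list_checkpoint_prefixes(model_path):
#         try:
#             saver.restore(sess, prefix)
#             break
#         except tf.errors.DataLossError:
#             continue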

def Test_IINN(iinn_: IINN, data: dict, model_path: str) -> float:
    xx = data['input']
    yy = data['output']

    x_t = iinn_.getInputPlaceHolder()  # tensor of inputs
    y_t = iinn_.getOutputTensor()      # tensor of outputs
    c_t = iinn_.getControlTensors()    # tensor of all control signals

    # set all the control signals to 0
    ctl_sig = []
    for i in range(len(c_t)):
        ctl_sig.append(np.array([0] * c_t[i].shape.as_list()[0]))

    sess = tf.Session()
    var_list = tf.trainable_variables()
    saver = tf.train.Saver(var_list=var_list)
    # load the pretrained model if it exists
    if tf.train.checkpoint_exists(model_path):
        saver.restore(sess, model_path)
        #utils.initialize_uninitialized(sess)
    else:
        raise NameError("failed to load checkpoint from path %s" % model_path)

    # inference
    labels_gt = np.argmax(yy, axis=-1)
    num_correct = 0

    for i in range(xx.shape[0]):
        feed_in = dict()
        feed_in[x_t] = xx[i:i + 1, :, :, :]
    </span><span class="sc5">for</span><span class="sc0"> </span><span class="sc11">j</span><span class="sc0"> </span><span class="sc5">in</span><span class="sc0"> </span><span class="sc11">range</span><span class="sc10">(</span><span class="sc11">len</span><span class="sc10">(</span><span class="sc11">c_t</span><span class="sc10">)):</span><span class="sc0">
        </span><span class="sc11">feed_in</span><span class="sc10">[</span><span class="sc11">c_t</span><span class="sc10">[</span><span class="sc11">j</span><span class="sc10">]]</span><span class="sc0"> </span><span class="sc10">=</span><span class="sc0"> </span><span class="sc11">ctl_sig</span><span class="sc10">[</span><span class="sc11">j</span><span class="sc10">]</span><span class="sc0">
    </span><span class="sc11">y</span><span class="sc0"> </span><span class="sc10">=</span><span class="sc0"> </span><span class="sc11">sess</span><span class="sc10">.</span><span class="sc11">run</span><span class="sc10">(</span><span class="sc11">y_t</span><span class="sc10">,</span><span class="sc0"> </span><span class="sc11">feed_dict</span><span class="sc10">=</span><span class="sc11">feed_in</span><span class="sc10">)[</span><span class="sc2">0</span><span class="sc10">]</span><span class="sc0">
    </span><span class="sc11">label_out</span><span class="sc0"> </span><span class="sc10">=</span><span class="sc0"> </span><span class="sc11">np</span><span class="sc10">.</span><span class="sc11">argmax</span><span class="sc10">(</span><span class="sc11">y</span><span class="sc10">)</span><span class="sc0">
    </span><span class="sc5">if</span><span class="sc0"> </span><span class="sc11">label_out</span><span class="sc0"> </span><span class="sc10">==</span><span class="sc0"> </span><span class="sc11">labels_gt</span><span class="sc10">[</span><span class="sc11">i</span><span class="sc10">]:</span><span class="sc0">
        </span><span class="sc11">num_correct</span><span class="sc0"> </span><span class="sc10">+=</span><span class="sc0"> </span><span class="sc2">1</span><span class="sc0">
</span><span class="sc5">return</span><span class="sc0"> </span><span class="sc11">float</span><span class="sc10">(</span><span class="sc11">num_correct</span><span class="sc10">)</span><span class="sc0"> </span><span class="sc10">/</span><span class="sc0"> </span><span class="sc11">float</span><span class="sc10">(</span><span class="sc11">len</span><span class="sc10">(</span><span class="sc11">labels_gt</span><span class="sc10">))</span><span class="sc0">

</span><span class="sc6">''' 
# iterative inference demo
for i in range(xx.shape[0]):
    x = xx[i]
    y = yy[i]
    y_trivial = np.ones(n_class)  # start from a trivial solution
    a = iinn_.attention(x, y_trivial)
    y = iinn_.inference(x, a)
    a = iinn_.attention(x, y)
    y = iinn_.inference(x, a)
    # ... this procedure goes on and on until converged
    pass
'''</span><span class="sc0">

if __name__ == "__main__":
    n_class = 10
    iinn_ = Build_IINN(n_class)

    </span><span class="sc1"># training with CIFAR-10 dataset</span><span class="sc0">
    </span><span class="sc11">data_train</span><span class="sc10">,</span><span class="sc0"> </span><span class="sc11">data_test</span><span class="sc0"> </span><span class="sc10">=</span><span class="sc0"> </span><span class="sc11">dataset</span><span class="sc10">.</span><span class="sc11">cifar10</span><span class="sc10">.</span><span class="sc11">Load_CIFAR10</span><span class="sc10">(</span><span class="sc4">'../Datasets/CIFAR10/'</span><span class="sc10">)</span><span class="sc0">
    </span><span class="sc11">model_path</span><span class="sc0"> </span><span class="sc10">=</span><span class="sc0"> </span><span class="sc4">'../Models/CIFAR10-IINN/ckpt_iinn_cifar10'</span><span class="sc0">
    </span><span class="sc11">Train_IINN</span><span class="sc10">(</span><span class="sc11">iinn_</span><span class="sc10">,</span><span class="sc0"> </span><span class="sc11">data_train</span><span class="sc10">,</span><span class="sc0"> </span><span class="sc11">model_path</span><span class="sc10">)</span><span class="sc0">
    </span><span class="sc1"># test the trained model with test split of the same dataset</span><span class="sc0">
    </span><span class="sc11">acc</span><span class="sc0"> </span><span class="sc10">=</span><span class="sc0"> </span><span class="sc11">Test_IINN</span><span class="sc10">(</span><span class="sc11">iinn_</span><span class="sc10">,</span><span class="sc0"> </span><span class="sc11">data_test</span><span class="sc10">,</span><span class="sc0"> </span><span class="sc11">model_path</span><span class="sc10">)</span><span class="sc0">
    </span><span class="sc5">print</span><span class="sc10">(</span><span class="sc3">"Accuracy = %6.5f"</span><span class="sc0"> </span><span class="sc10">%</span><span class="sc0"> </span><span class="sc11">acc</span><span class="sc10">)</span><span class="sc0">
</span></div></body>

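The Train_IINN function applies this strategy. Note that the global_step variable should not end up in the saved model: create the saver before creating the global_step variable, and pass tf.trainable_variables() as its variable list.
That way the checkpoint keeps only the variables needed for the forward pass, and temporary variables used only during training are discarded.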
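Keeping several step-numbered checkpoints only helps if loading can fall back to an older one when the newest is damaged. The following is a minimal sketch of that fallback, not the original author's code. It assumes the checkpoints were written with saver.save(sess, prefix, global_step=...), so that each one has an index file named <basename>-<step>.index. The helper names list_checkpoint_prefixes and restore_latest_valid, and the restore_fn callback, are hypothetical.

```python
import os
import re


def list_checkpoint_prefixes(ckpt_dir, basename):
    """Return checkpoint prefixes '<ckpt_dir>/<basename>-<step>', newest step first."""
    pattern = re.compile(r"^(%s-(\d+))\.index$" % re.escape(basename))
    found = []
    for fname in os.listdir(ckpt_dir):
        match = pattern.match(fname)
        if match:
            step = int(match.group(2))
            found.append((step, os.path.join(ckpt_dir, match.group(1))))
    found.sort(reverse=True)  # highest global_step first
    return [prefix for _, prefix in found]


def restore_latest_valid(ckpt_dir, basename, restore_fn):
    """Call restore_fn(prefix) newest-first; return the first prefix that loads.

    restore_fn is a hypothetical callback that must raise an exception when
    the checkpoint cannot be loaded. Returns None if no checkpoint loads.
    """
    for prefix in list_checkpoint_prefixes(ckpt_dir, basename):
        try:
            restore_fn(prefix)
            return prefix
        except Exception as err:  # damaged checkpoint: try the next older one
            print("skipping %s: %s" % (prefix, err))
    return None
```

With the TF 1.x code above, restore_fn could be something like lambda prefix: saver.restore(sess, prefix), which is expected to raise when a checkpoint fails its checksum. Files left over from an interrupted save (such as the .tempstate files mentioned earlier) do not match the .index pattern and are ignored.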

Original article: https://www.cnblogs.com/thisisajoke/p/12033274.html