【638】keras 多输出模型【实战】

【530】keras 实现多输出模型

[Keras] [multiple inputs / outputs] ValueError: No data provided for "xx". Need data for each key...

keras:怎样使用 fit_generator 来训练多个不同类型的输出

1. model.compile

  对于多输出模型而言,多出来一个字典的形式,通过 model.compile 里面包含的 loss、loss_weight,可以通过字典的形式设置,如下所示:

model.compile(optimizer='rmsprop',
			  # 不同输出层对应的损失函数
			  loss={'outputs1': 'binary_crossentropy',
			  		'outputs1': 'binary_crossentropy'},
			  # 不同输出层对应的损失函数权重值
			  loss_weight={'outputs1': 0.5,
			  			   'outputs1': 0.5})

  注意:字典的 key 值并不是随意设置的,需要前后一致,并且需要指定到具体的模型输出的名称以及数据生成器中的,否则是无法对应的。

2. 模型架构

  因为是单一输入就不考虑输入的名称了,输出的名称需要对应,如下所示:

# 输入
inputs = keras.Input(...)

# 模型主体部分
...

# 输出
outputs1 = layers.Conv2D(1, 3, activation="sigmoid", padding="same", name="outputs1")(x1)
outputs2 = layers.Conv2D(1, 3, activation="sigmoid", padding="same", name="outputs2")(x2)

model = keras.Model(inputs, [outputs1, outputs2])

  注意:outputs1 与 outputs2 里面的 name 值与上面对应

3. 数据生成器

  数据生成器需要生成对应格式的数据,特别是通过 key 值来对应输出数据的 labels,如下所示:

# 图像生成器,生成可以直接输入到模型中的 generator,返回值是 tuple
class ImageGenerator(keras.utils.Sequence):
    """Helper to iterate over the data (as Numpy arrays)."""
    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths_louding, target_img_paths_louti):
        ...

    def __len__(self):
        return len(self.input_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Returns tuple (input, target) correspond to batch #idx."""
        
        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        ...
            
        y1 = np.zeros((self.batch_size,) + self.img_size + (1,), dtype="float32")
        ...
            
        y2 = np.zeros((self.batch_size,) + self.img_size + (1,), dtype="float32")
        ...
        
        # 注意 key 值的对应
        y = {'outputs1': y1, 'outputs2': y2}
        
        return x, y

  总结:实际上多输出或者多输入与单输入单输出没有实质性的区别,就是在数据处理和衔接上面容易出现问题,只要将 key 值对应,无论是 fit 还是 fit_generator 都可以实现。

原文地址:https://www.cnblogs.com/alex-bn-lee/p/15119398.html