CS231n assignment3 Q4 Style Transfer

"Image Style Transfer Using Convolutional Neural Networks" (Gatys et al., CVPR 2016).
用代码复现这篇论文中的风格迁移方法。
loss 由三部分组成：内容 loss、风格 loss 和正则化（全变分，total variation）loss，其中风格 loss 基于 Gram 矩阵计算。

Content loss

def content_loss(content_weight, content_current, content_original):
    """
    Compute the content loss for style transfer.

    The loss is the weighted sum of element-wise squared differences between
    the features of the current image and those of the content image:
        content_weight * sum((content_current - content_original) ** 2)

    Inputs:
    - content_weight: scalar constant we multiply the content_loss by.
    - content_current: features of the current image, Tensor with shape [1, height, width, channels]
    - content_original: features of the content image, Tensor with shape [1, height, width, channels]

    Returns:
    - scalar content loss
    """
    # tf.squared_difference(x, y) computes (x - y) * (x - y) element-wise;
    # reduce_sum then collapses every axis to a single scalar.
    return content_weight * tf.reduce_sum(
        tf.squared_difference(content_current, content_original))

Style loss

def gram_matrix(features, normalize=True):
    """
    Compute the Gram matrix from features.

    Inputs:
    - features: Tensor of shape (1, H, W, C) giving features for
      a single image. A leading batch dimension N > 1 is also accepted;
      one Gram matrix is then computed per image.
    - normalize: optional, whether to normalize the Gram matrix
        If True, divide the Gram matrix by the number of neurons of one
        image (H * W * C).

    Returns:
    - gram: Tensor of shape (N, C, C) -- i.e. (1, C, C) for a single image --
      giving the (optionally normalized) Gram matrices for the input image(s).
      The leading batch axis is kept, so Gram matrices produced by this
      function can be compared directly with each other.
    """
    # (N, H, W, C) -> (N, C, H, W), then flatten the spatial axes so that
    # each channel becomes one row vector: (N, C, H * W).
    features = tf.transpose(features, [0, 3, 1, 2])
    shape = tf.shape(features)  # (N, C, H, W) after the transpose
    features = tf.reshape(features, (shape[0], shape[1], -1))
    # Batched F @ F^T yields the (N, C, C) channel-correlation matrices.
    result = tf.matmul(features, features, transpose_b=True)
    if normalize:
        # Normalize by the neurons of ONE image (C * H * W). The batch size
        # must not enter the denominator, otherwise every per-image Gram
        # matrix would be shrunk by a factor of N.
        num_neurons = shape[1] * shape[2] * shape[3]
        # Cast to the result's own dtype instead of hard-coding float32, so
        # float64 / float16 features also work.
        result = result / tf.cast(num_neurons, result.dtype)
    return result

def style_loss(feats, style_layers, style_targets, style_weights):
    """
    Computes the style loss at a set of layers.

    For every selected layer, the Gram matrix of the current image's features
    (computed with gram_matrix) is compared to the precomputed target Gram
    matrix. The sum of squared differences is scaled by that layer's weight,
    and the per-layer terms are summed.

    Inputs:
    - feats: list of the features at every layer of the current image, as produced by
      the extract_features function.
    - style_layers: List of layer indices into feats giving the layers to include in the
      style loss.
    - style_targets: List of the same length as style_layers, where style_targets[i] is
      a Tensor giving the Gram matrix of the source style image computed at
      layer style_layers[i].
    - style_weights: List of the same length as style_layers, where style_weights[i]
      is a scalar giving the weight for the style loss at layer style_layers[i].

    Returns:
    - style_loss: A Tensor containing the scalar style loss (the Python int 0
      when style_layers is empty).

    Raises:
    - ValueError: if style_layers, style_targets and style_weights do not all
      have the same length.
    """
    # Validate up front. Otherwise a mismatch either fails deep inside the
    # loop or silently ignores trailing targets/weights.
    if not (len(style_layers) == len(style_targets) == len(style_weights)):
        raise ValueError(
            "style_layers, style_targets and style_weights must have the same "
            "length, got %d, %d and %d"
            % (len(style_layers), len(style_targets), len(style_weights)))

    style_losses = 0
    for layer, target_gram, weight in zip(style_layers, style_targets,
                                          style_weights):
        # style_targets already holds Gram matrices; only the current image's
        # feature map at this layer has to be converted into one.
        current_gram = gram_matrix(feats[layer])
        style_losses += weight * tf.reduce_sum(
            tf.squared_difference(current_gram, target_gram))
    return style_losses

Total-variation regularization

def tv_loss(img, tv_weight):
    """
    Compute the (squared) total-variation loss of an image.

    Sums the squared differences between every pair of vertically adjacent
    pixels and every pair of horizontally adjacent pixels, over all channels,
    then scales the total by tv_weight.

    Inputs:
    - img: Tensor of shape (1, H, W, 3) holding an input image.
    - tv_weight: Scalar giving the weight w_t to use for the TV loss.

    Returns:
    - loss: Tensor holding a scalar giving the total variation loss
      for img weighted by tv_weight.
    """
    # Fully vectorized: compare the image with itself shifted by one pixel
    # along the height axis and along the width axis.
    top_rows = img[:, :-1, :, :]
    bottom_rows = img[:, 1:, :, :]
    left_cols = img[:, :, :-1, :]
    right_cols = img[:, :, 1:, :]

    vertical_term = tf.reduce_sum(tf.squared_difference(top_rows, bottom_rows))
    horizontal_term = tf.reduce_sum(tf.squared_difference(left_cols, right_cols))
    return tv_weight * (vertical_term + horizontal_term)

原文地址:https://www.cnblogs.com/bernieloveslife/p/10224313.html