contrastive loss

昨天Ke Kiaming大神的新文章 MoCo问世,reID中的contrastive loss逐渐往上游影响。自己对这一块一直没有一个总结梳理,趁着学习这篇文章的机会整理一下,挖个坑慢慢填


Distance metric learning aims to learn an embedding representation of the data that preserves the distance between similar data points close and dissimilar data points far on the embedding space1

1. Improved Deep Metric Learning with Multi-class N-pair Loss Objective [NIP2016] [pdf] [code-pytorch]

  • 定义: 

    • 样本数据$x in mathcal{X}$,标签为$y in {1,2,...,L}$,$x_+,x_-$分别表述输入样本的正负样本对(同类/不同类);$f(.; heta): mathcal{X} ightarrow mathbb{R}^K$表示feature embedding,$f(x)$是feature embedding vector;$m$是margin
    • Contrastive loss takes pairs of examples as input and trains a network to predict whether two inputs are from the same class or not: $$mathcal{L}^{m}_{cont}(x_i,x_j;f)=mathbb{1}{y_i=y_j}||f_i-f_j||_2^2+mathbb{1}{y_i e y_j}max(0,m-||f_i-f_j||_2^2)^2$$
    • Triplet loss shares a similar spirit to contrastive loss, but is composed of triplets, each consisting of a query, a positive example (to the query), and a negative example:$$mathcal{L}^{m}_{cont}(x,x_+,x_-;f)=max(0, ||f-f_+||_2^2+||f-f_-||_2^2)^2+m$$

  • 方法

    • 针对之前的相关方法每次只有一个负样本训练存在的收敛慢、局部最优(可通过hard negative mining解决)的问题,提出了multi-class N-pair loss和negative class mining
    • multi-class N-pair loss
      • 每个batch选$N$个class,每个class选一对样本,即${(x_1, x_1^+),...,(x_N, x_N^+)}$,建立N个tuplets: ${S_i}_{i=1}^N$,其中$S_i={x_i, x_1^+, x_2^+, ..., x_N^+}$构成了一对正样本和N-1对负样本
      • 损失函数$$mathcal{L}_{N-pair-mc}({(x_i,x_i^+)}_{i=1}^N;f)= frac{1}{N} sum_{i=1}^N log(1+sum_{j e i} exp(f_i^T f_j^+ - f_i^T f_i^+))$$
      • 而$$log(1+sum_{i=1}^{L-1} exp(f^T f_i - f^T f^+))=-log frac{exp(f^T f^+)}{exp(f^T f^+)+ sum_{i=1}^{L-1}exp(f^T f_i )}$$ 和softmax非常像!
    • negative class mining
      1. Evaluate Embedding Vectors: choose randomly a large number of output classes C; for each class, randomly pass a few (one or two) examples to extract their embedding vectors.
      2. Select Negative Classes: select one class randomly from C classes from step 1. Next, greedily add a new class that violates triplet constraint the most w.r.t. the selected classes till we reach N classes. When a tie appears, we randomly pick one of tied classes. 这一步不是很懂,大概是每次从剩下类中找最hard的类的sample,直到N?
      3. Finalize N-pair: draw two examples from each selected class from step 2.

2. Unsupervised Feature Learning via Non-Parametric Instance Discrimination [pdf] [code-pytorch]

3. Momentum Contrast for Unsupervised Visual Representation Learning

原文地址:https://www.cnblogs.com/xiaoaoran/p/11876754.html