深入学习ThreadLocal原理

  上文我们学习了ThreadLocal的基本用法以及基本原理,ThreadLocal中的方法并不多,基本用到的也就get、set、remove等方法,但是其核心逻辑还是在定义在ThreadLocal内部的静态内部类ThreadLocalMap中,里面有很多设计非常精妙的地方,本文中我们就从ThreadLocalMap的角度入手深入学习ThreadLocal的原理。

 1. 基本数据结构

  按照官方的解释是:这是一个定制化的Hash类型的map,专门用来保存线程本地变量。其内部采用是通过一个自定义的Entry来封装数据,并且保存在一个Entry数组中。为了便于处理大量且长时间存活的对象引用(其实是ThreadLocal),Entry采用WeakReference作为key的类型,当map中空间不够时,key为null的ertry将会被删除。ThreadLocalMap内部数据结构如下:

static class ThreadLocalMap {

  static class Entry extends WeakReference<ThreadLocal<?>> {
      /** 要保存到线程本地的变量 */
      Object value;

      Entry(ThreadLocal<?> k, Object v) {
          super(k);
          value = v;
      }
  }

  /**
   * 数组初始容量 -- 必须为2的倍数.
   */
  private static final int INITIAL_CAPACITY = 16;

  /**
   * 存储entry的数组,长度为2的倍数
   */
  private Entry[] table;

  /**
   * entries数量
   */
  private int size = 0;

  /**
   * resize阈值
   */
  private int threshold; // Default to 0

  /**
   * 计算阈值
   */
  private void setThreshold(int len) {
      threshold = len * 2 / 3;
  }

  /**
   * i+1,大于等于len则从0开始继续
   */
  private static int nextIndex(int i, int len) {
      return ((i + 1 < len) ? i + 1 : 0);
  }

  /**
   * i-1,小于0则从len-1开始继续
   */
  private static int prevIndex(int i, int len) {
      return ((i - 1 >= 0) ? i - 1 : len - 1);
  }

  ......

}

  在ThreadLocalMap内部通过自定义的Entry类来封装要保存的数据,以ThreadLocal类型对象为key,Object类型对象为value。这个Entry继承自WeakReference<ThreadLocal<?>>,每个Entry都可以是一个指向ThreadLocal对象的弱引用,可通过Entry的get方法来获取对ThreadLocal对象的引用,而这个引用就是key。所有的Entry统一保存在一个Entry数组table中,数组的长度必须为2的倍数,通过key的hashcode与数组长度减1进行与运算来定位Entry在数组中的存储位置,这点和hashmap类似,但是当发生hash碰撞时hashmap的处理方法是放入链表或者树中(都在同一个hash桶中),而ThreadLocalMap则是依次往后查找可以保存的地方,没有桶的概念(这点后面会结合代码详细讲)。

  

  既然ThreadLocalMap内部是一个数组,通过key的hashcode来定位到数组下标,这里我们不得不说一下key的hashcode的生成方式,非常精妙,因为key类型为ThreadLocal,所以其hashcode的生成方式也在ThreadLocal中:

  private final int threadLocalHashCode = nextHashCode();

  private static AtomicInteger nextHashCode = new AtomicInteger();

  private static final int HASH_INCREMENT = 0x61c88647;

  private static int nextHashCode() {
     return nextHashCode.getAndAdd(HASH_INCREMENT);
  }

  对于每个ThreadLocal对象,都有一个独自不变的hashcode,每新增一个ThreadLocal对象,会自动生成其自己的hashcode,其实就是让nextHashCode自增0x61c88647,目的是为了让生成的hashcode均匀的分布在2的幂次方上,而数组长度也是2的幂次方,这样就保证了要插入的元素可以均匀分布在数组中。

  虽然ThreadLocal使用了很牛逼的办法来生成hashcode,但是还是不可避免会产生hash碰撞,当出现碰撞时是如何来处理呢?我们接着看:

2. 获取元素

  我们知道ThreadLocalMap是以Entry为基本单元保存数据的,而且是以key-value对的形式,我们先来看一下是如何通过key获取到Entry的:

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

  这个逻辑比较简单:

  • 首先通过key的hashcode获取数组下标(与运算);
  • 如果下标对应处Entry不为空,且key与传入的key是指向同一个ThreadLocal对象则认为找到,直接返回Entry;
  • 否则执行getEntryAfterMiss;
/**
* 有三种情况下会执行这个方法
* 1. e为null;
* 2. e!=null,e的key=null;
* 3. e!=null,e的key!=null,e的key!=要找的key,即出现hash碰撞
**/
private
Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) { Entry[] tab = table; int len = tab.length; while (e != null) { ThreadLocal<?> k = e.get(); if (k == key) return e; if (k == null) expungeStaleEntry(i); else i = nextIndex(i, len); // 出现碰撞,则依次往后找 e = tab[i]; } return null; }

  这里的逻辑也比较清晰:

  • 获取内部保存Entry的数组及数组长度;
  • 获取传入Entry对应的key,如果和传入的key相等则直接返回key;
  • 如果Entry对应的key为空,则执行expungeStaleEntry,传入的参数为当前Entry所在数组下标i;
  • 否则将获取e在数组中后面那个元素并赋值给e,如果e不为空,则循环从第2步执行,否则直接退出循环;

  对于key为空的Entry在ThreadLocal里面称为staleSlot,接下来看一下expungeStaleEntry:

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 直接将下标为staleSlot处的元素擦除,value和Entry都要擦除
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash操作直到数组对应下标处元素为空的情况
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

  逻辑会稍微复杂一些,我们还是一步一步看:

  • 获取内部保存Entry的数组及数组长度;
  • key为空代表这个Entry已经不需要了,直接置空,帮助gc,并将size减1;
  • 从传入的staleSlot下标后面的元素开始,依次遍历过去,循环执行下面的操作,直到遇到Entry为空停止;
  • 如果Entry为staleSlot(即key为null),则清空;
  • 否则检查该Entry是否在它应该在的位置(根据hashcode计算出来的下标与其实际下标是否相等);
  • 如果不在则将当前slot置为空,继续往后寻找,直到一个Entry为空的slot,将其放进去,重复下一次循环;

  expungeStaleEntry的作用是清除传入的staleSlot处的Entry,除此之外还会管两件"闲事":

  • 从其后面开始清除遇到的staleSlot;
  • rehash计算下标与实际下标不相符的Entry,
  • 直到遇到Entry为空的slot则停止。

  从上面的分析我们得出,通过key获取元素时,如果从计算出来的下标能获取到符合要求的值则直接返回,否则会从该位置开始依次往后找;遇到Entry不为空但是Entry的key为空的会擦除该Entry并继续循环;遇到Entry不为空且key不为空(hash碰撞)则直接往后找;在整个找的过程中遇到Entry为null则停止查找,直接返回null。

3. 设置元素 

  接下来我们看看设置元素,也就是set方法:

private void set(ThreadLocal<?> key, Object value) {

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
     // 找到则直接替换,然后直接返回
        if (k == key) {
            e.value = value;
            return;
        }
     // 发现staleSlot,则执行replaceStaleEntry,然后直接返回
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
   // 如果没有找到,则new一个Entry插入数组中
    tab[i] = new Entry(key, value);
    int sz = ++size;
   // 插入新的Etry之后需要试探的去擦除一些过期的slot(key=null的Entry),如果Entry数量大于阈值,则执行扩容
if (!cleanSomeSlots(i, sz) && sz >= threshold) rehash(); }

  这也是一个私有方法,这里看起来代码不多,但是里面涉及到的东西很多,逻辑也要比get方法复杂,但是没关系,我们层层递进,一一分解。

  • 获取Entry数组、数组长度以及通过要插入的key的hashcode计算出其在数组中的下标;
  • 拿到下标之后,对应下标处如果有Entry存在,则有三种情况:
    • key不为空,且等于要插入的key,则直接将value替换成要执行的value,返回;
    • key为空,则执行replaceStaleEntry中的逻辑,返回;
    • 如果key不为空但是又不等于要插入的key,则取下标i处后一个元素,循环执行上面的操作;
  • 如果如上的循环结束,到这里代表没有找到要插入的key,且当前i处的Entry为空,则直接new一个Entry,将待插入的key和value放入其中,再放入数组;
  • 将代表数组中Entry数量的size加1;
  • 执行cleanSomeSlots中的逻辑,如果有删除一些Slot,并且size大于阈值,则需要执行rehash中的逻辑进行扩容,否则set执行结束;

  上面的步骤看完之后,我们来看看其中当key为空时需要执行的replaceStaleEntry的逻辑:

private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // 现在staleSlot处对应的Entry其key=null,往前查找看是否能不能找到一个stale的Entry
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // Find either the key or trailing null slot of run, whichever
    // occurs first

    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // 找到了直接替换,替换之后再尝试删除一些stale的Entry
        if (k == key) {
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }
        // 如果i处对应的Entry是stale,并且前面往前没有找到stale的Entry,则将i标识为待擦除的slot
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // 如果没有找到传入key对应的entry,则new一个新Entry放在传入staleSlot下标处,现在staleSlot处的Entry不再是stale(过期的)了
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // 如果还发现有其他stale entries存在, 将其清除
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

  这个replaceStaleEntry的逻辑比较难理解,只要清楚它主要干了下面两件事:

  • 尝试查找和传入key对应的Entry,找到则替换,没找到则在传入的staleSlot处插入一个新的Entry;
  • 在上面的过程中,尽力地去擦除一些找到的staleSlot;

  以及插入一个新的Entry之后,试探性地去删除多余的staleSlot(注意,是试探性的哦),逻辑在cleanSomeSlots中:

/**
* @param i 扫描起始下标,从第i+1处开始扫描
*
* @param n 扫描次数控制量,在往后面扫描的过程中,如果没有发现staleSlot,则最多扫描log2(n)个元素,否则在staleSlot之后再扫log2(table.length-1)个
**/
private
boolean cleanSomeSlots(int i, int n) {
   // 标识是否有删除过staleSlot
   boolean removed = false; Entry[] tab = table; int len = tab.length; do { i = nextIndex(i, len); Entry e = tab[i]; if (e != null && e.get() == null) { n = len; removed = true; i = expungeStaleEntry(i); } } while ( (n >>>= 1) != 0); return removed; }

  从i+1处开始,往后扫描,如果遇到staleSlot,则执行expungeStaleEntry,往后扫描log2(n)次结束循环,n为传入的参数,如果发现staleSlot,则将n更新为Entry数组长度len。

  这个设计非常巧妙,试探性的扫描一些单元看是否能发现staleSlot(不新鲜的entrys,也就是key=null)。当一个新元素添加进来或者一个staleSlot被清除的时候,会调用这个方法。该方法扫描元素的数量是对数级的,如果不扫描就不能及时清除key为null的entry(会浪费内存),如果全数组扫描则会导致一次插入的时间复杂度为O(n),采用这种试探性的扫描方式其实是一种在功能和性能之间的平衡,尽最大努力清理垃圾,又不导致过于消耗性能。

  如果插入了新Entry,且执行了cleanSomeSlots之后size的数量还是大于阈值的话,这时就需要rehash扩容了:

private void rehash() {
    expungeStaleEntries();

    // Use lower threshold for doubling to avoid hysteresis
    if (size >= threshold - threshold / 4)
        resize();
}

// 扫描全表,清除所有staleSlot
private void expungeStaleEntries() { Entry[] tab = table; int len = tab.length; for (int j = 0; j < len; j++) { Entry e = tab[j]; if (e != null && e.get() == null) expungeStaleEntry(j); } }
// 将表容量扩大一倍
private void resize() { Entry[] oldTab = table; int oldLen = oldTab.length; int newLen = oldLen * 2; Entry[] newTab = new Entry[newLen]; int count = 0; for (int j = 0; j < oldLen; ++j) { Entry e = oldTab[j]; if (e != null) { ThreadLocal<?> k = e.get(); if (k == null) { e.value = null; // Help the GC } else { int h = k.threadLocalHashCode & (newLen - 1); while (newTab[h] != null) h = nextIndex(h, newLen); newTab[h] = e; count++; } } } setThreshold(newLen); size = count; table = newTab; }

  首先扫描全表,清除所有staleSlot,如果这还不能减小size,则将table容量扩大一倍。扩容的逻辑比较简单,根据新数组容量来计算新的数组下标,如果存在hash冲突就往后找,直到Entry为空则把元素放进去。

  到这里我们学习了ThreadLocal的基本原理、核心数据结构、最常用的get和set方法,是不是对ThreadLocal有了更深入的了解呢?如果有,那非常高兴我的文章能给你带来一丁点价值^_^

4. 内存泄漏

  前面有讲到,ThreadLocalMap中的Entry其类型是属于弱引用(继承了WeakReference),被弱引用指向的对象,在下一次GC时是会被回收的,除非这个对象还有强引用指向它(对Java中强、软、弱、虚引用不清楚的同学可以详细了解下),之所以这样设计,我的理解是Entry是存在ThreadLocalMap中,而这个map又是保存在线程thread中的,用户是不能直接获取到的,也是不能直接操作的,也就会影响到垃圾回收。为了避免因为ThreadLocalMap存储了ThreadLocal对象而影响到ThreadLocal对象的垃圾回收,JDK的设计者把主动权完全交给调用方,一旦调用方不想使用,只需设置ThreadLocal对象为null,内存就可以被回收掉了,这也是弱引用的一个主要使用场景。

  另一方面,在set和getEntry的过程中会频繁的去清理stale entry,以及时释放空余位置,这样就可以及时清除value,因为value是我们要保存到ThreadLocal中的值,而这是强引用,即便是key被回收了,value依然不会被回收。

  虽然ThreadLocal中做了种种设计来防止内存泄漏,但是如果使用不当还是会导致内存泄漏,我这里借用一个网上的例子,一起来感受下:

public class ThreadLocalLeakDemo {
  
  public static void main(String[] args) {
    new Thread(new Runnable() {

      @Override
      public void run() {
        for(int i = 0; i< 1000 ;i++) {
          TestClass t = new TestClass(i);
          t.printId();
      // 行1,注释掉这一行时不会导致内存溢出 t
= null;
      // 行2,注释掉这一行时会导致内存溢出 t.threadLocal.remove(); } } }).start();; }
static class TestClass{ private int id; private int[] arr;
   // 注意,这是一个普通成员哦
private ThreadLocal<TestClass> threadLocal; TestClass(int id){ this.id = id; arr = new int[1000000]; threadLocal = new ThreadLocal(); threadLocal.set(this); } public void printId() { System.out.println(threadLocal.get().id); } } }

/**
* 注释行2,放开行1时,会导致内存溢出,结果如下:
**/
...
449
450
451
Exception in thread "Thread-0" java.lang.OutOfMemoryError: Java heap space
at testDemos.annotationDemos.ThreadLocalLeakDemo$TestClass.<init>(ThreadLocalLeakDemo.java:28)
at testDemos.annotationDemos.ThreadLocalLeakDemo$1.run(ThreadLocalLeakDemo.java:13)
at java.lang.Thread.run(Unknown Source)
...

/**
* 注释行1,放开行2时,不会导致内存泄漏,结果如下:
**/
...
997
998
999

  上面其实就是改了一行代码,就导致内存溢出,增加的那一步操作就是调用了ThreadLocal的remove,那我们就来看看remove的逻辑:

  移除元素的逻辑很简单,根据传入的key定位到数组下标i,从这个下标开始往后循环,直到遇到Entry为空时停止循环。如果找到key对应的entry,则调用Entry的clear方法。

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}

   结合上面的例子和源码,我们解释一下为什么没有调用remove方法会导致内存溢出。如上,在不调用remove时,每一次循环都会插入一个新的Entry对象到ThreadLocalMap中,这个Entry是指向一个新的ThreadLocal对象,对于这个ThreadLocal对象存在两个引用:

  • Entry-->ThreadLocal,这是弱引用;
  • Entry-->value(TestClass)-->ThreadLocal,这是强引用;

  由于强引用一直存在,而t=null并不能让value不可达,因为value是保存在线程本地内存中的,所以没法回收这个新的ThreadLocal对象,导致一直堆积,最终报OOM

  而如果调用remove的话,则会直接将对应Entry以及其保存的value清空,这样就不会内存泄漏了。

  其实上面的例子是使用不当导致的,如果将ThreadLocal成员变量置为static,也不会出现这个问题,因为即便有1000次循环,但是都是用的同一个ThreadLocal,在线程本地始终只有一份,用private static来修饰ThreadLocal也是一个官方推荐的惯用法。

5. 总结

  1. ThreadLocal内部数据结构:Entry数组
  2. Entry封装要保存的数据,以key-value的形式,key的类型为指向ThreadLocal的WeakReference,value为要保存的对象
  3. 通过key的hashcode来初步定位其在数组中的位置,如果没有则往后依次查找,如果找到则返回(getEntry)或替换(set),直到碰到为空的Entry为止,这就是解决hash碰撞所采用的方法;
  4. 当出现hash冲突时,ThreadLocalMap采用的办法就是继续往后面找,这是线性操作所以会比较低效。但是ThreadLocal采用的散列算法效果很好,冲突的概率非常小,再加上在set和getEntry的过程中会频繁的去清理stale entry(expungeStaleEntry、replaceStaleEntry、cleanSomeSlots中都有涉及到),是为了能够及时释放空余位置,进一步降低这种低效带来的影响。
  5. 由于Entry是指向ThreadLocal对象的弱引用,所以当ThreadLocal对象不存在强引用的时候,是可以被回收的,回收之后Entry就指向空了(get获取的key为null),但是这时候Entry中的value仍然不为空,可以可能导致内存泄漏,有两种方式可以清除:
  •   在ThreadLocal的get、set方法中会频繁的去清除staleSlot
  •   手动调用TreadLocal的remove方法来清除

  以上为个人总结,如有不对,烦请指正。

原文地址:https://www.cnblogs.com/volcano-liu/p/10712524.html