ThreadLocal原理及用法详解

背景

一直以来对ThreadLocal用法模棱两可,不知道怎么用今天好好研究了下给大家分享下。

  • 1、讲解ThreadLocal之前先回顾下什么是取模、x^y、弱引用。

1. 取模运算实际上是计算两数相除以后的余数.

假设a除以b的商是c,d是相对应的余数,那么几乎所有的计算机系统都满足a = c* b + d。所以d=a-b*c。
 17 % 10 的计算结果如下:d = (17) - (17 / 10) x 10 = (17) - (1 x 10) = 7
 -17 % 10 的计算结果如下:d = (-17) - (-17 / 10) x 10 = (-17) - (-1 x 10) = -7
-17 % -10 的计算结果如下:d = 17 - (17 / -10) x (-10) = (17) - (-1 x -10) = 7
-17 % -10 的计算结果如下:d= (-17) - (-17 / -10) x (-10) = (-17) - (1 x -10) = -7
可以看出:运算结果的符号始终和被模数的符号一致。

2. x^y 按位取异或。

如:x是二进制数0101 y是二进制数1011 则结果为x^y=1110,0^1=1,0^0=0,1^1=0,1^0=1!只要有一个为1就取值为1。

3. 弱引用

弱引用也是用来描述非必需对象的,当JVM进行垃圾回收时,无论内存是否充足,都会回收被弱引用关联的对象。在java中,用java.lang.ref.WeakReference类来表示。这里所说的被弱引用关联的对象是指只有弱引用与之关联,如果存在强引用同时与之关联,则进行垃圾回收时也不会回收该对象。

  • 2、我们要问两个问题

1.什么是ThreadLocal
ThreadLocal从字面意识可以理解为线程本地变量。也就是说如果定义了ThreadLocal,每个线程往这个ThreadLocal中读写是线程隔离,互相之间不会影响的。它提供了一种将可变数据通过每个线程有自己的独立副本从而实现线程封闭的机制。
2.实现的思路是什么
Tread类有一个类型为ThreadLocal.ThreadLocalMap的实例变量threadLocals,我们使用线程的时候有一个自己的ThreadLocalMap。ThreadLocalMap有自己的独立实现,可以简单认为ThreadLocal视为key,value为代码中放入的值(实际上key并不是ThreadLocal本身,通过源码可以知道它是一个弱引用)。调用ThreadLocal的set方法时候,都会存到ThreadLocalMap里面。调用ThreadLocal的get方法时候,在自己map里面找key,从而实现线程隔离。

  • 3、ThreadLocal最主要的实现在于ThreadLocalMap这个内部类里面,我们重点关注ThreadLocalMap这个类的用法,看看两位大师Josh Bloch and Doug Lea是什么设计出如此好的类。

  • 4、ThreadLocalMap

 
image.png

上面是ThreadLocalMap所有的API
ThreadLocalMap提供了一种为ThreadLocal定制的高效实现,并且自带一种基于弱引用的垃圾清理机制。

  1. 存储结构方面
    存储结构可以理解为一个map,但是不要和java.util.Map弄混。
 static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

从上面代码可以看出Entry便是ThreadLocalMap里定义的节点,它继承了WeakReference类,定义了一个类型为Object的value,用于存放塞到ThreadLocal里的值,key可以视为为ThreadLocal。

  1. 为什么要用弱引用,因为如果这里使用普通的key-value形式来定义存储结构,实质上就会造成节点的生命周期与线程强绑定,只要线程没有销毁,那么节点在GC分析中一直处于可达状态,没办法被回收,而程序本身也无法判断是否可以清理节点。弱引用是Java四种引用的的第三种(其它三种强引用、软引用、虚引用),比软引用更加弱一些,如果一个对象没有强引用链可达,那么一般活不过下一次GC。当某个ThreadLocal已经没有强引用可达,则随着它被垃圾回收,在ThreadLocalMap里对应的Entry的键值会失效,这为ThreadLocalMap本身的垃圾清理提供了便利。
  2. Entry里面的成员变量和方法
       /**
         * 初始容量,必须为2的幂.
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * 根据需要调整大小。
         * 长度必须总是2的幂。
         */
        private Entry[] table;

        /**
         * 表中条目的数量。
         */
        private int size = 0;

        /**
         * 要调整大小的下一个大小值。默认为0
         */
        private int threshold; // Default to 0

        /**
         * 将调整大小阈值设置维持最坏2/3的负载因子。
         */
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        /**
         *上一个索引
         */
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }

        /**
         * 下一个索引
         */
        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }

由于ThreadLocalMap使用线性探测法来解决散列冲突,所以实际上Entry[]数组在程序逻辑上是作为一个环形存在的。

  1. 构造函数
         /**
         * Construct a new map initially containing (firstKey, firstValue).
         * ThreadLocalMaps are constructed lazily, so we only create
         * one when we have at least one entry to put in it.
         * 构造一个最初包含(firstKey, firstValue)的新映射threadlocalmap是延迟构造的,因      
         *  此当我们至少有一个元素可以放进去的时候才去创建。
         */
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

重点说下这个hash函数int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
ThreadLocal类中有一个被final修饰的类型为int的threadLocalHashCode,它在该ThreadLocal被构造的时候就会生成,相当于一个ThreadLocal的ID,而它的值来源于

    private final int threadLocalHashCode = nextHashCode();

    /**
     * The next hash code to be given out. Updated atomically. Starts at
     * zero.
     */
    private static AtomicInteger nextHashCode =
        new AtomicInteger();

    /**
     * 连续生成的哈希码之间的区别——循环隐式顺序线程本地id以近乎最优的方式展开
     * 用于两倍大小表的乘法哈希值。
     */
    private static final int HASH_INCREMENT = 0x61c88647;

    /**
     * 返回下一个hashcode
     */
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

通过理论和实践算出通,当我们用0x61c88647作为魔数累加为每个ThreadLocal分配各自的ID也就是threadLocalHashCode再与2的幂取模,得到的结果分布很均匀。

ThreadLocalMap使用的是线性探测法,均匀分布的好处在于很快就能探测到下一个临近的可用slot,从而保证效率。这就回答了上文抛出的为什么大小要为2的幂的问题。为了优化效率。对于& (INITIAL_CAPACITY - 1),相信有过算法阅读源码较多的程序员,一看就明白,对于2的幂作为模数取模,可以用&(2n-1)来替代%2n,位运算比取模效率高很多。至于为什么,因为对2^n取模,只要不是低n位对结果的贡献显然都是0,会影响结果的只能是低n位。
  1. TreadLocal中的get方法
    ThreadLocal中的get方法会调用这个ThreadLocalMap中的getEntry方法,
      private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }

        /**
         * Version of getEntry method for use when key is not found in
         * its direct hash slot.
         *
         * @param  key the thread local object
         * @param  i the table index for key's hash code
         * @param  e the entry at table[i]
         * @return the entry associated with key, or null if no such
         */
        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }
 private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        } private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

这上面的注释很清楚,这里都不在多讲了

简单说下当调用get方法时候遇到的情况

根据入参threadLocal的threadLocalHashCode对表容量取模得到index如果index对应的slot就是要读的threadLocal,则直接返回结果
调用getEntryAfterMiss线性探测,过程中每碰到无效slot,调用expungeStaleEntry进行段清理;如果找到了key,则返回结果entry
没有找到key,返回null

  1. TreadLocal中的set方法
    /**
         * Set the value associated with key.
         *
         * @param key the thread local object
         * @param value the value to be set
         */
        private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

      /**
         * Replace a stale entry encountered during a set operation
         * with an entry for the specified key.  The value passed in
         * the value parameter is stored in the entry, whether or not
         * an entry already exists for the specified key.
         *
         * As a side effect, this method expunges all stale entries in the
         * "run" containing the stale entry.  (A run is a sequence of entries
         * between two null slots.)
         *
         * @param  key the key
         * @param  value the value to be associated with key
         * @param  staleSlot index of the first stale entry encountered while
         *         searching for key.
         */
        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            // Back up to check for prior stale entry in current run.
            // We clean out whole runs at a time to avoid continual
            // incremental rehashing due to garbage collector freeing
            // up refs in bunches (i.e., whenever the collector runs).
            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            // Find either the key or trailing null slot of run, whichever
            // occurs first
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();

                // If we find key, then we need to swap it
                // with the stale entry to maintain hash table order.
                // The newly stale slot, or any other stale slot
                // encountered above it, can then be sent to expungeStaleEntry
                // to remove or rehash all of the other entries in run.
                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // Start expunge at preceding stale entry if it exists
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // If we didn't find stale entry on backward scan, the
                // first stale entry seen while scanning for key is the
                // first still present in the run.
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // If key not found, put new entry in stale slot
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // If there are any other stale entries in run, expunge them
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

        /**
         * Heuristically scan some cells looking for stale entries.
         * This is invoked when either a new element is added, or
         * another stale one has been expunged. It performs a
         * logarithmic number of scans, as a balance between no
         * scanning (fast but retains garbage) and a number of scans
         * proportional to number of elements, that would find all
         * garbage but would cause some insertions to take O(n) time.
         *
         * @param i a position known NOT to hold a stale entry. The
         * scan starts at the element after i.
         *
         * @param n scan control: {@code log2(n)} cells are scanned,
         * unless a stale entry is found, in which case
         * {@code log2(table.length)-1} additional cells are scanned.
         * When called from insertions, this parameter is the number
         * of elements, but when from replaceStaleEntry, it is the
         * table length. (Note: all this could be changed to be either
         * more or less aggressive by weighting n instead of just
         * using straight log n. But this version is simple, fast, and
         * seems to work well.)
         *
         * @return true if any stale entries have been removed.
         */
        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }

    private void rehash() {
            expungeStaleEntries();

            // Use lower threshold for doubling to avoid hysteresis
            if (size >= threshold - threshold / 4)
                resize();
        }

        /**
         * Double the capacity of the table.
         */
        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;

            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }

            setThreshold(newLen);
            size = count;
            table = newTab;
        }

        /**
         * Expunge all stale entries in the table.
         */
        private void expungeStaleEntries() {
            Entry[] tab = table;
            int len = tab.length;
            for (int j = 0; j < len; j++) {
                Entry e = tab[j];
                if (e != null && e.get() == null)
                    expungeStaleEntry(j);
            }
        }
简单总结下上面的用法

a、探测过程中slot都不无效,并且顺利找到key所在的slot,直接替换即可
b、探测过程中发现有无效slot,调用replaceStaleEntry,效果是最终一定会把key和value放在这个slot,并且会尽可能清理无效slot
在replaceStaleEntry过程中,如果找到了key,则做一个swap把它放到那个无效slot中,value置为新值
在replaceStaleEntry过程中,没有找到key,直接在无效slot原地放entry
c、探测没有发现key,则在连续段末尾的后一个空位置放上entry,这也是线性探测法的一部分。放完后,做一次启发式清理,如果没清理出去key,并且当前table大小已经超过阈值了,则做一次rehash,rehash函数会调用一次全量清理slot方法也expungeStaleEntries,如果完了之后table大小超过了threshold - threshold / 4,则进行扩容2倍
6、ThreadLocal中的remove方法调用TreadLocalMap中的remove

 /**
         * Remove the entry for key.
         */
        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }

这个方法比较简单,找到key就清除

  1. ThreadLocal会不会遇到内存泄漏问题
    有关于内存泄露是因为在有线程复用如线程池的场景中,一个线程的生命周期很长,大对象长期不被回收影响系统运行效率与安全。如果线程不会复用,用完即销毁了也不会有ThreadLocal引发内存泄露的问题。
    仔细读过ThreadLocalMap的源码,可以推断,如果在使用的ThreadLocal的过程中,习惯加上remove,这样是不会引起内存泄漏。
    如果没有进行remove呢?如果对应线程之后调用ThreadLocal的get和set方法都有很高的概率会顺便清理掉无效对象,断开value强引用,从而大对象被收集器回收。
    我们应该考虑到何时调用ThreadLocal的remove方法。一个比较熟悉的场景就是对于一个请求一个线程的server如tomcat,在代码中对web api作一个切面,存放一些如用户名等用户信息,在连接点方法结束后,再显式调用remove。

以上都是理论,我们做一个小实验

/**
 * @author shuliangzhao
 * @Title: ThreadLocalDemo
 * @ProjectName design-parent
 * @Description: TODO
 * @date 2019/6/1 0:00
 */
public class ThreadLocalDemo {
    private static ThreadLocal<ThreadLocalDemo> t = new ThreadLocal<>();
    private ThreadLocalDemo() {}

    public static ThreadLocalDemo getInstance() {
        ThreadLocalDemo threadLocalDemo = ThreadLocalDemo.t.get();
        if (null == threadLocalDemo) {
            threadLocalDemo = new ThreadLocalDemo();
            t.set(threadLocalDemo);
        }
        return threadLocalDemo;
    }
    private String name;

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public static void main(String[] args) {
      long i =  (1L << 32) - (long) ((1L << 31) * (Math.sqrt(5) - 1));
        System.out.println(i);
        int i1 = 55&(2^2-1);
        System.out.println(55%2^2);
        System.out.println(i1);
    }
}
/**
 * @author shuliangzhao
 * @Title: ThreadLocalTest
 * @ProjectName design-parent
 * @Description: TODO
 * @date 2019/6/1 0:03
 */
public class ThreadLocalTest {
    public static void main(String[] args) {
        for(int i=0; i<2;i++){
            new Thread(new Runnable() {
                @Override
                public void run() {
                    Double d = Math.random()*10;
                    ThreadLocalDemo.getInstance().setName("name "+d);
                    new A().get();
                    new B().get();
                }
            }).start();
        }
    }
    static class A{
        public void get(){
            System.out.println(ThreadLocalDemo.getInstance().getName());
        }
    }
    static class B{
        public void get(){
            System.out.println(ThreadLocalDemo.getInstance().getName());
        }
    }

}

运行结果

 
image.png

到这里我们就把ThreadLocal讲完了,可以多看看源码,看看大牛们是怎么设计出如此优美的代码。

原文地址:https://www.cnblogs.com/treeshu/p/10959611.html