ThreadLocal 源码分析

线程局部变量

ThreadLocal 用于实现线程隔离和类间变量共享。

创建实例

    /**
     * 当前 ThreadLocal 实例的哈希值
     */
    private final int threadLocalHashCode = nextHashCode();

    /**
     * 下一个 ThreadLocal 实例的哈希值
     */
    private static AtomicInteger nextHashCode =
            new AtomicInteger();

    /**
     * 下一个 ThreadLocal 实例的哈希值增量,可以最大程度地避免碰撞
     */
    private static final int HASH_INCREMENT = 0x61c88647;

    /**
     * 读取下一个哈希值
     */
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

    /**
     * 创建一个线程局部变量
     */
    public ThreadLocal() {
    }

写入

    /**
     * 将目标值 value 写入此 ThreadLocal 变量中
     */
    public void set(T value) {
        // 读取当前线程
        final Thread t = Thread.currentThread();
        // 读取与当前线程绑定的 ThreadLocalMap
        final ThreadLocalMap map = getMap(t);
        // 1)ThreadLocalMap 已经存在,则直接写入
        if (map != null) {
            map.set(this, value);
        } else {
        // 2)创建 ThreadLocalMap 并写入  
            createMap(t, value);
        }
    }

读取

    /**
     * 读取此 ThreadLocal 实例关联的值
     */
    public T get() {
        final Thread t = Thread.currentThread();
        final ThreadLocalMap map = getMap(t);
        if (map != null) {
            // 读取 Entry
            final ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                final
                // 读取值
                T result = (T)e.value;
                return result;
            }
        }
        // 此 ThreadLocal 还未初始化,则写入初始值
        return setInitialValue();
    }

移除

    /**
     * 移除此 ThreadLocal
     */
    public void remove() {
        final ThreadLocalMap m = getMap(Thread.currentThread());
        if (m != null) {
            m.remove(this);
        }
    }

ThreadLocalMap

    static class ThreadLocalMap {
        /**
         * ThreadLocalMap 的 Entry 继承了 WeakReference,以便能处理大量的条目,
         * 当 entry.get()==null 时,表示关联的 ThreadLocal 对象已经被回收,该条目
         * 可以从此 ThreadLocalMap 中移除了
         */
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** 与  ThreadLocal 关联的值 */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * table 的初始容量
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * 容纳键值对的底层 table
         */
        private Entry[] table;

        /**
         * table 中的条目数
         */
        private int size = 0;

        /**
         * 下一次扩容的阈值
         */
        private int threshold; // Default to 0

        /**
         * 写入下次扩容的阈值,table 容量的 2/3
         */
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        /**
         * 索引递增 1
         */
        private static int nextIndex(int i, int len) {
            return i + 1 < len ? i + 1 : 0;
        }

        /**
         * 索引递减 1
         */
        private static int prevIndex(int i, int len) {
            return i - 1 >= 0 ? i - 1 : len - 1;
        }

        /**
         * 创建 ThreadLocalMap 实例,并写入第一个 entry
         */
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            // 创建初始容量为 16 的 table
            table = new Entry[INITIAL_CAPACITY];
            // 第一个键的 threadLocalHashCode 为 0,即目标索引也为 0
            final int i = firstKey.threadLocalHashCode & INITIAL_CAPACITY - 1;
            // 创建并写入 Entry
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            // 写入阈值,table 容量的 2/3
            setThreshold(INITIAL_CAPACITY);
        }

        /**
         * Construct a new map including all Inheritable ThreadLocals
         * from given parent map. Called only by createInheritedMap.
         *
         * @param parentMap the map associated with parent thread.
         */
        private ThreadLocalMap(ThreadLocalMap parentMap) {
            final Entry[] parentTable = parentMap.table;
            final int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];

            for (final Entry e : parentTable) {
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    final
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        final Object value = key.childValue(e.value);
                        final Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & len - 1;
                        while (table[h] != null) {
                            h = nextIndex(h, len);
                        }
                        table[h] = c;
                        size++;
                    }
                }
            }
        }

        /**
         * 读取指定 key 关联的 Entry
         */
        private Entry getEntry(ThreadLocal<?> key) {
            // 计算索引值
            final int i = key.threadLocalHashCode & table.length - 1;
            final Entry e = table[i];
            // 读取 Entry 的键和目标 key 相等,则直接返回
            if (e != null && e.get() == key) {
                return e;
            } else {
                // 往后查找 Entry
                return getEntryAfterMiss(key, i, e);
            }
        }

        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            final Entry[] tab = table;
            final int len = tab.length;

            while (e != null) {
                final ThreadLocal<?> k = e.get();
                // 当前 Entry 的键和目标 key 相等,则返回
                if (k == key) {
                    return e;
                }
                // 1)当前弱键已经被回收,则移除此 Entry
                if (k == null) {
                    expungeStaleEntry(i);
                } else {
                    // 2)索引值递增 1,循环查找
                    i = nextIndex(i, len);
                }
                // 定位下一个 Entry
                e = tab[i];
            }
            // 未找到则返回 null
            return null;
        }

        /**
         * 将指定的 ThreadLocal 键值对写入此 ThreadLocalMap 中
         */
        private void set(ThreadLocal<?> key, Object value) {
            /**
             * We don't use a fast path as with get() because it is at
             * least as common to use set() to create new entries as
             * it is to replace existing ones, in which case, a fast
             * path would fail more often than not.
             */
            // 读取 table
            final Entry[] tab = table;
            // 读取 length
            final int len = tab.length;
            // 基于 ThreadLocal 的哈希值计算索引
            int i = key.threadLocalHashCode & len-1;
            /**
             * 读取目标索引处的 Entry && 此 Entry 不为 null
             * 1)如果当前 Entry 不匹配,则循环查找下一个 Entry,直到下一个 Entry 为空
             */
            for (Entry e = tab[i];
                    e != null;
                    e = tab[i = nextIndex(i, len)]) {
                // 读取 Entry 的 ThreadLocal
                final ThreadLocal<?> k = e.get();
                // 当前 ThreadLocal 与目标 ThreadLocal 相等,则更新其值
                if (k == key) {
                    e.value = value;
                    return;
                }
                // 如果弱键已经被回收,则移除过时的 Entry
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            // 未找到,则创建并写入新的 Entry
            tab[i] = new Entry(key, value);
            // 递增 size
            final int sz = ++size;
            /**
             * 1)尝试清除索引 i 之后的一些 slot,如果清除成功,则此 table 无需扩容
             * 2)slot 清除失败 && 判断当前总元素数是否超出扩容阈值 && 超出则进行扩容
             */
            if (!cleanSomeSlots(i, sz) && sz >= threshold) {
                rehash();
            }
        }

        /**
         * 移除指定 key 关联的 Entry
         */
        private void remove(ThreadLocal<?> key) {
            final Entry[] tab = table;
            final int len = tab.length;
            int i = key.threadLocalHashCode & len-1;
            for (Entry e = tab[i];
                    e != null;
                    e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    // 清除弱引用
                    e.clear();
                    // 清除此 Entry
                    expungeStaleEntry(i);
                    return;
                }
            }
        }

        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                int staleSlot) {
            final Entry[] tab = table;
            final int len = tab.length;
            Entry e;

            /**
             * 循环查找过时的 Entry,并尝试更新索引,直到遇到一个空 slot 为止
             */
            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                    (e = tab[i]) != null;
                    i = prevIndex(i, len)) {
                if (e.get() == null) {
                    slotToExpunge = i;
                }
            }

            // Find either the key or trailing null slot of run, whichever occurs first
            for (int i = nextIndex(staleSlot, len);
                    (e = tab[i]) != null;
                    i = nextIndex(i, len)) {
                final ThreadLocal<?> k = e.get();

                /**
                 * If we find key, then we need to swap it with the stale entry to maintain hash table order.
                 * The newly stale slot, or any other stale slot encountered above it,
                 * can then be sent to expungeStaleEntry to remove or rehash all of the other entries in run.
                 */
                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // Start expunge at preceding stale entry if it exists
                    if (slotToExpunge == staleSlot) {
                        slotToExpunge = i;
                    }
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // If we didn't find stale entry on backward scan, the
                // first stale entry seen while scanning for key is the
                // first still present in the run.
                if (k == null && slotToExpunge == staleSlot) {
                    slotToExpunge = i;
                }
            }

            // If key not found, put new entry in stale slot
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // If there are any other stale entries in run, expunge them
            if (slotToExpunge != staleSlot) {
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            }
        }

        private int expungeStaleEntry(int staleSlot) {
            final Entry[] tab = table;
            final int len = tab.length;

            // 移除 staleSlot 索引处 ThreadLocal 关联的值
            tab[staleSlot].value = null;
            // 移除 staleSlot 的 Entry
            tab[staleSlot] = null;
            // 递减 size
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            // 计算被异常条目之后的 index && 定位的 Entry 不为 null
            for (i = nextIndex(staleSlot, len);
                    (e = tab[i]) != null;
                    i = nextIndex(i, len)) {
                // 读取键
                final ThreadLocal<?> k = e.get();
                // 此 Entry 关联的弱键已经被回收
                if (k == null) {
                    // 清除值
                    e.value = null;
                    // 清除 Entry
                    tab[i] = null;
                    // 递减 size
                    size--;
                } else {
                    // 计算索引
                    int h = k.threadLocalHashCode & len - 1;
                    // 如果 table 已经扩容
                    if (h != i) {
                        // 将旧的 Entry 置为 null
                        tab[i] = null;

                        // 从新的索引位置开始寻找一个空的 slot
                        while (tab[h] != null) {
                            h = nextIndex(h, len);
                        }
                        // 写入 Entry
                        tab[h] = e;
                    }
                }
            }
            return i;
        }


        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            final Entry[] tab = table;
            final int len = tab.length;
            do {
                // 计算下一个索引
                i = nextIndex(i, len);
                // 读取 Entry
                final Entry e = tab[i];
                // 如果此 Entry 条目不为空 && 条目关联的值为 null
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    // 移除过时的 Entry
                    i = expungeStaleEntry(i);
                }
                // 将 n 减半
            } while ( (n >>>= 1) != 0);
            return removed;
        }

        private void rehash() {
            expungeStaleEntries();

            // 当前元素总数 >= 阈值的 3/4
            if (size >= threshold - threshold / 4) {
                resize();
            }
        }

        /**
         * 双倍扩容
         */
        private void resize() {
            final Entry[] oldTab = table;
            final int oldLen = oldTab.length;
            final int newLen = oldLen * 2;
            final Entry[] newTab = new Entry[newLen];
            int count = 0;

            for (final Entry e : oldTab) {
                if (e != null) {
                    final ThreadLocal<?> k = e.get();
                    // 1)弱键已经被移除,则清空其值
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        // 2)基于当前 ThreadLocal 的哈希值定位索引,顺序找到第一个可用的 slot
                        int h = k.threadLocalHashCode & newLen - 1;
                        while (newTab[h] != null) {
                            h = nextIndex(h, newLen);
                        }
                        // 写入 Entry
                        newTab[h] = e;
                        // 递增元素数
                        count++;
                    }
                }
            }
            setThreshold(newLen);
            size = count;
            table = newTab;
        }

        /**
         * 清除所有过时的 Entry
         */
        private void expungeStaleEntries() {
            final Entry[] tab = table;
            final int len = tab.length;
            for (int j = 0; j < len; j++) {
                final Entry e = tab[j];
                if (e != null && e.get() == null) {
                    expungeStaleEntry(j);
                }
            }
        }
    }
原文地址:https://www.cnblogs.com/zhuxudong/p/10210893.html