清理ThreadLocal

在我很多的课程里(masterconcurrencyxj-conc-j8),我经常提起ThreadLocal。它经常受到我严厉的指责要尽可能的避免使用。ThreadLocal是为了那些使用完就销毁的线程设计的。线程生成之前,线程内的局部变量都会被清除掉。实际上,如果你读过 Why 0x61c88647?,这篇文章中解释了实际的值是存在一个内部的map中,这个map是伴随着线程的产生而产生的。存在于线程池的线程不会只存活于单个用户请求,这很容易导致内存溢出。通常我在讨论这个的时候,至少有一位学生有过因为在一个线程局部变量持有某个类而导致整个生产系统奔溃。因此,预防实体类加载后不被卸载,是一个非常普遍的问题。

在这篇文章中,我将演示一个ThreadLocalCleaner类。通过这个类,可以在线程回到线程池之前,恢复所有本地线程变量到最开始的状态。最基础的,我们可以保存当前线程的ThreadLocal的状态,之后再重置。我们可以使用Java 7提供的try-with-resource结构来完成这件事情。例如:

try (ThreadLocalCleaner tlc = new ThreadLocalCleaner()) {
  // some code that potentially creates and adds thread locals
}
// at this point, the new thread locals have been cleared

为了简化调试,我们增加一个观察者的机制,这样我们能够监测到线程局部map发生的任何变化。这能帮助我们发现可能出现泄漏的线程局部变量。这是我们的监听器:

package threadcleaner;
@FunctionalInterface
public interface ThreadLocalChangeListener {
  void changed(Mode mode, Thread thread,
               ThreadLocal<?> threadLocal, Object value);

  ThreadLocalChangeListener EMPTY = (m, t, tl, v) -> {};

  ThreadLocalChangeListener PRINTER =
      (m, t, tl, v) -> System.out.printf(
          "Thread %s %s ThreadLocal %s with value %s%n",
          t, m, tl.getClass(), v);

  enum Mode {
    ADDED, REMOVED
  }
}

这个地方可能需要做一下必要的说明。首先,我添加了注解@FunctionalInterface,这个注解是Java 8提供的,它的意思是该类只有一个抽象方法,可以作为lambda表达式使用。其次,我在该类的内部定义了一个EMPTY的lambda表达式。这样,你可以见识到,这段代码会非常短小。第三,我还定义了一个默认的PRINTER,它可以简单的通过System.out输出改变的信息。最后,我们还有两个不不同的事件,但是因为想设计成为一个函数式编程接口(@FunctionalInterface),我不得不把这个标示定义为单独的属性,这里定义成了枚举。

当我们构造ThreadLocalCleaner时,我们可以传递一个ThreadLocalChangeListener。这样,从Treadlocal创建开始发生的任何变化,我们都能监测到。请注意,这种机制只适合于当前线程。这有一个例子演示我们怎样通过try-with-resource代码块来使用ThreadLocalCleaner:任何定义在在 try(…) 中的局部变量,都会在代码块的自后进行自动关闭。因此,我们需要在ThreadLocalCleaner内部有一个 close() 方法,用于恢复线程局部变量到初始值。

import java.text.*;
public class ThreadLocalCleanerExample {
  private static final ThreadLocal df =
      new ThreadLocal() {
        protected DateFormat initialValue() {
          return new SimpleDateFormat("yyyy-MM-dd");
        }
      };

  public static void main(String... args) {
    System.out.println("First ThreadLocalCleaner context");
    try (ThreadLocalCleaner tlc = new ThreadLocalCleaner(
        ThreadLocalChangeListener.PRINTER)) {
      System.out.println(System.identityHashCode(df.get()));
      System.out.println(System.identityHashCode(df.get()));
      System.out.println(System.identityHashCode(df.get()));
    }

    System.out.println("Another ThreadLocalCleaner context");
    try (ThreadLocalCleaner tlc = new ThreadLocalCleaner(
        ThreadLocalChangeListener.PRINTER)) {
      System.out.println(System.identityHashCode(df.get()));
      System.out.println(System.identityHashCode(df.get()));
      System.out.println(System.identityHashCode(df.get()));
    }
  }
}

在ThreadLocalCleaner类中还有两个公共的静态方法:forEach() 和 cleanup(Thread)。forEach() 方法有两个参数:Thread和BiConsumer。该方法通过ThreadLocal调用来遍历其中的每一个值。我们跳过了key为null的对象,但是没有跳过值为null的对象。理由是如果仅仅是值为null,ThreadLocal仍然有可能出现内存泄露。一旦我们使用完了ThreadLocal,就应该在将线程返回给线程池之前调用 remove() 方法。cleanup(Thread) 方法设置ThreadLocal的map在该线程内为null,因此,允许垃圾回收器回收所有的对象。如果一个ThreadLocal在我们清理后再次使用,就简单的调用 initialValue() 方法来创建对象。这是方法的定义:

public static void forEach(Thread thread,
      BiConsumer<ThreadLocal<?>, Object> consumer) { ... }

public static void cleanup(Thread thread) { ... }

ThreadLocalCleaner类完整的代码如下。该类使用了许多反射来操作私有域。它可能只能在OpenJDK或其直接衍生产品上运行。你也能注意到我使用了Java 8的语法。我纠结过很长一段时间是否使用Java 8 或 7。我的某些客户端还在使用1.4。最后,我的大部分大银行客户已经在产品中开始使用Java 8。银行通常来说不是最先采用的新技术人,除非存在非常大的经济意义。因此,如果你还没有在产品中使用Java 8,你应该尽可能快的移植过去,甚至可以跳过Java 8,直接到Java 9。你应该可以很容易的反向移植到Java 7上,只需要自定义一个BiConsumer接口。Java 6不支持try-with-resource结构,所以反向移植会比较困难一点。

package threadcleaner;

import java.lang.ref.*;
import java.lang.reflect.*;
import java.util.*;
import java.util.function.*;

import static threadcleaner.ThreadLocalChangeListener.Mode.*;

public class ThreadLocalCleaner implements AutoCloseable {
  private final ThreadLocalChangeListener listener;

  public ThreadLocalCleaner() {
    this(ThreadLocalChangeListener.EMPTY);
  }

  public ThreadLocalCleaner(ThreadLocalChangeListener listener) {
    this.listener = listener;
    saveOldThreadLocals();
  }

  public void close() {
    cleanup();
  }

  public void cleanup() {
    diff(threadLocalsField, copyOfThreadLocals.get());
    diff(inheritableThreadLocalsField,
        copyOfInheritableThreadLocals.get());
    restoreOldThreadLocals();
  }

  public static void forEach(
      Thread thread,
      BiConsumer<ThreadLocal<?>, Object> consumer) {
    forEach(thread, threadLocalsField, consumer);
    forEach(thread, inheritableThreadLocalsField, consumer);
  }

  public static void cleanup(Thread thread) {
    try {
      threadLocalsField.set(thread, null);
      inheritableThreadLocalsField.set(thread, null);
    } catch (IllegalAccessException e) {
      throw new IllegalStateException(
          "Could not clear thread locals: " + e);
    }
  }

  private void diff(Field field, Reference<?>[] backup) {
    try {
      Thread thread = Thread.currentThread();
      Object threadLocals = field.get(thread);
      if (threadLocals == null) {
        if (backup != null) {
          for (Reference<?> reference : backup) {
            changed(thread, reference,
                REMOVED);
          }
        }
        return;
      }

      Reference<?>[] current =
          (Reference<?>[]) tableField.get(threadLocals);
      if (backup == null) {
        for (Reference<?> reference : current) {
          changed(thread, reference, ADDED);
        }
      } else {
        // nested loop - both arrays *should* be relatively small
        next:
        for (Reference<?> curRef : current) {
          if (curRef != null) {
            if (curRef.get() == copyOfThreadLocals ||
                curRef.get() == copyOfInheritableThreadLocals) {
              continue next;
            }
            for (Reference<?> backupRef : backup) {
              if (curRef == backupRef) continue next;
            }
            // could not find it in backup - added
            changed(thread, curRef, ADDED);
          }
        }
        next:
        for (Reference<?> backupRef : backup) {
          for (Reference<?> curRef : current) {
            if (curRef == backupRef) continue next;
          }
          // could not find it in current - removed
          changed(thread, backupRef,
              REMOVED);
        }
      }
    } catch (IllegalAccessException e) {
      throw new IllegalStateException("Access denied", e);
    }
  }

  private void changed(Thread thread, Reference<?> reference,
                       ThreadLocalChangeListener.Mode mode)
      throws IllegalAccessException {
    listener.changed(mode,
        thread, (ThreadLocal<?>) reference.get(),
        threadLocalEntryValueField.get(reference));
  }

  private static Field field(Class<?> c, String name)
      throws NoSuchFieldException {
    Field field = c.getDeclaredField(name);
    field.setAccessible(true);
    return field;
  }

  private static Class<?> inner(Class<?> clazz, String name) {
    for (Class<?> c : clazz.getDeclaredClasses()) {
      if (c.getSimpleName().equals(name)) {
        return c;
      }
    }
    throw new IllegalStateException(
        "Could not find inner class " + name + " in " + clazz);
  }

  private static void forEach(
      Thread thread, Field field,
      BiConsumer<ThreadLocal<?>, Object> consumer) {
    try {
      Object threadLocals = field.get(thread);
      if (threadLocals != null) {
        Reference<?>[] table = (Reference<?>[])
            tableField.get(threadLocals);
        for (Reference<?> ref : table) {
          if (ref != null) {
            ThreadLocal<?> key = (ThreadLocal<?>) ref.get();
            if (key != null) {
              Object value = threadLocalEntryValueField.get(ref);
              consumer.accept(key, value);
            }
          }
        }
      }
    } catch (IllegalAccessException e) {
      throw new IllegalStateException(e);
    }
  }

  private static final ThreadLocal<Reference<?>[]>
      copyOfThreadLocals = new ThreadLocal<>();

  private static final ThreadLocal<Reference<?>[]>
      copyOfInheritableThreadLocals = new ThreadLocal<>();

  private static void saveOldThreadLocals() {
    copyOfThreadLocals.set(copy(threadLocalsField));
    copyOfInheritableThreadLocals.set(
        copy(inheritableThreadLocalsField));
  }

  private static Reference<?>[] copy(Field field) {
    try {
      Thread thread = Thread.currentThread();
      Object threadLocals = field.get(thread);
      if (threadLocals == null) return null;
      Reference<?>[] table =
          (Reference<?>[]) tableField.get(threadLocals);
      return Arrays.copyOf(table, table.length);
    } catch (IllegalAccessException e) {
      throw new IllegalStateException("Access denied", e);
    }
  }

  private static void restoreOldThreadLocals() {
    try {
      restore(threadLocalsField, copyOfThreadLocals.get());
      restore(inheritableThreadLocalsField,
          copyOfInheritableThreadLocals.get());
    } finally {
      copyOfThreadLocals.remove();
      copyOfInheritableThreadLocals.remove();
    }
  }

  private static void restore(Field field, Object value) {
    try {
      Thread thread = Thread.currentThread();
      if (value == null) {
        field.set(thread, null);
      } else {
        tableField.set(field.get(thread), value);
      }
    } catch (IllegalAccessException e) {
      throw new IllegalStateException("Access denied", e);
    }
  }

  /* Reflection fields */

  private static final Field threadLocalsField;

  private static final Field inheritableThreadLocalsField;
  private static final Class<?> threadLocalMapClass;
  private static final Field tableField;
  private static final Class<?> threadLocalMapEntryClass;

  private static final Field threadLocalEntryValueField;

  static {
    try {
      threadLocalsField = field(Thread.class, "threadLocals");
      inheritableThreadLocalsField =
          field(Thread.class, "inheritableThreadLocals");

      threadLocalMapClass =
          inner(ThreadLocal.class, "ThreadLocalMap");

      tableField = field(threadLocalMapClass, "table");
      threadLocalMapEntryClass =
          inner(threadLocalMapClass, "Entry");

      threadLocalEntryValueField =
          field(threadLocalMapEntryClass, "value");
    } catch (NoSuchFieldException e) {
      throw new IllegalStateException(
          "Could not locate threadLocals field in Thread.  " +
              "Will not be able to clear thread locals: " + e);
    }
  }
}

这是一个ThreadLocalCleaner在实践中应用的例子:

import java.text.*;

public class ThreadLocalCleanerExample {
  private static final ThreadLocal<DateFormat> df =
      new ThreadLocal<DateFormat>() {
        protected DateFormat initialValue() {
          return new SimpleDateFormat("yyyy-MM-dd");
        }
      };

  public static void main(String... args) {
    System.out.println("First ThreadLocalCleaner context");
    try (ThreadLocalCleaner tlc = new ThreadLocalCleaner(
        ThreadLocalChangeListener.PRINTER)) {
      System.out.println(System.identityHashCode(df.get()));
      System.out.println(System.identityHashCode(df.get()));
      System.out.println(System.identityHashCode(df.get()));
    }

    System.out.println("Another ThreadLocalCleaner context");
    try (ThreadLocalCleaner tlc = new ThreadLocalCleaner(
        ThreadLocalChangeListener.PRINTER)) {
      System.out.println(System.identityHashCode(df.get()));
      System.out.println(System.identityHashCode(df.get()));
      System.out.println(System.identityHashCode(df.get()));
    }
  }
}

你的输出结果可能会包含不同的hash code值。但是请记住我在Identity Crisis Newsletter中所说的:hash code的生成算法是一个随机数字生成器。这是我的输出。注意,在try-with-resource内部,线程局部变量的值是相同的。

First ThreadLocalCleaner context
186370029
186370029
186370029
Thread Thread[main,5,main] ADDED ThreadLocal class 
    ThreadLocalCleanerExample$1 with value 
    java.text.SimpleDateFormat@f67a0200
Another ThreadLocalCleaner context
2094548358
2094548358
2094548358
Thread Thread[main,5,main] ADDED ThreadLocal class 
    ThreadLocalCleanerExample$1 with value 
    java.text.SimpleDateFormat@f67a0200

为了让这个代码使用起来更简单,我写了一个Facede。门面设计模式不是阻止用户使用直接使用子系统,而是提供一种更简单的接口来完成复杂的系统。最典型的方式是将最常用子系统作为方法提供,我们的门面包括两个方法:findAll(Thread) 和 printThreadLocals() 方法。findAll() 方法返回一个线程内部的Entry集合。

package threadcleaner;

import java.io.*;
import java.lang.ref.*;
import java.util.AbstractMap.*;
import java.util.*;
import java.util.Map.*;
import java.util.function.*;

import static threadcleaner.ThreadLocalCleaner.*;

public class ThreadLocalCleaners {
  public static Collection<Entry<ThreadLocal<?>, Object>> findAll(
      Thread thread) {
    Collection<Entry<ThreadLocal<?>, Object>> result =
        new ArrayList<>();
    BiConsumer<ThreadLocal<?>, Object> adder =
        (key, value) ->
            result.add(new SimpleImmutableEntry<>(key, value));
    forEach(thread, adder);
    return result;
  }

  public static void printThreadLocals() {
    printThreadLocals(System.out);
  }

  public static void printThreadLocals(Thread thread) {
    printThreadLocals(thread, System.out);
  }

  public static void printThreadLocals(PrintStream out) {
    printThreadLocals(Thread.currentThread(), out);
  }

  public static void printThreadLocals(Thread thread,
                                       PrintStream out) {
    out.println("Thread " + thread.getName());
    out.println("  ThreadLocals");
    printTable(thread, out);
  }

  private static void printTable(
      Thread thread, PrintStream out) {
    forEach(thread, (key, value) -> {
      out.printf("    {%s,%s", key, value);
      if (value instanceof Reference) {
        out.print("->" + ((Reference<?>) value).get());
      }
      out.println("}");
    });
  }
}

线程可以包含两个不同类型的ThreadLocal:一个是普通的,另一个是可继承的。大部分情况下,我们使用普通的那个。可继承意味着如果你从当前线程构造出一个新的线程,则所有可继承的ThreadLocals将被新线程继承过去。非常同意。我们很少这么使用。所以现在我们可以忘了这种情况,或者永远忘记。

一个使用ThreadLocalCleaner的典型场景是和ThreadPoolExecutor一起使用。我们写一个子类,覆盖 beforeExecute() 和 afterExecute() 方法。这个类比较长,因为我们不得不编写所有的构造函数。有意思的地方在最后面。

package threadcleaner;

import java.util.concurrent.*;

public class ThreadPoolExecutorExt extends ThreadPoolExecutor {
  private final ThreadLocalChangeListener listener;

  // Bunch of constructors following - you can ignore those 

  public ThreadPoolExecutorExt(
      int corePoolSize, int maximumPoolSize, long keepAliveTime,
      TimeUnit unit, BlockingQueue<Runnable> workQueue) {
    this(corePoolSize, maximumPoolSize, keepAliveTime, unit,
        workQueue, ThreadLocalChangeListener.EMPTY);
  }

  public ThreadPoolExecutorExt(
      int corePoolSize, int maximumPoolSize, long keepAliveTime,
      TimeUnit unit, BlockingQueue<Runnable> workQueue,
      ThreadFactory threadFactory) {
    this(corePoolSize, maximumPoolSize, keepAliveTime, unit,
        workQueue, threadFactory,
        ThreadLocalChangeListener.EMPTY);
  }

  public ThreadPoolExecutorExt(
      int corePoolSize, int maximumPoolSize, long keepAliveTime,
      TimeUnit unit, BlockingQueue<Runnable> workQueue,
      RejectedExecutionHandler handler) {
    this(corePoolSize, maximumPoolSize, keepAliveTime, unit,
        workQueue, handler,
        ThreadLocalChangeListener.EMPTY);
  }

  public ThreadPoolExecutorExt(
      int corePoolSize, int maximumPoolSize, long keepAliveTime,
      TimeUnit unit, BlockingQueue<Runnable> workQueue,
      ThreadFactory threadFactory,
      RejectedExecutionHandler handler) {
    this(corePoolSize, maximumPoolSize, keepAliveTime, unit,
        workQueue, threadFactory, handler,
        ThreadLocalChangeListener.EMPTY);
  }

  public ThreadPoolExecutorExt(
      int corePoolSize, int maximumPoolSize, long keepAliveTime,
      TimeUnit unit, BlockingQueue<Runnable> workQueue,
      ThreadLocalChangeListener listener) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit,
        workQueue);
    this.listener = listener;
  }

  public ThreadPoolExecutorExt(
      int corePoolSize, int maximumPoolSize, long keepAliveTime,
      TimeUnit unit, BlockingQueue<Runnable> workQueue,
      ThreadFactory threadFactory,
      ThreadLocalChangeListener listener) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit,
        workQueue, threadFactory);
    this.listener = listener;
  }

  public ThreadPoolExecutorExt(
      int corePoolSize, int maximumPoolSize, long keepAliveTime,
      TimeUnit unit, BlockingQueue<Runnable> workQueue,
      RejectedExecutionHandler handler,
      ThreadLocalChangeListener listener) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit,
        workQueue, handler);
    this.listener = listener;
  }

  public ThreadPoolExecutorExt(
      int corePoolSize, int maximumPoolSize, long keepAliveTime,
      TimeUnit unit, BlockingQueue<Runnable> workQueue,
      ThreadFactory threadFactory,
      RejectedExecutionHandler handler,
      ThreadLocalChangeListener listener) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit,
        workQueue, threadFactory, handler);
    this.listener = listener;
  }

  /* The interest bit of this class is below ... */

  private static final ThreadLocal<ThreadLocalCleaner> local =
      new ThreadLocal<>();

  protected void beforeExecute(Thread t, Runnable r) {
    assert t == Thread.currentThread();
    local.set(new ThreadLocalCleaner(listener));
  }

  protected void afterExecute(Runnable r, Throwable t) {
    ThreadLocalCleaner cleaner = local.get();
    local.remove();
    cleaner.cleanup();
  }
}

你可以像使用一个普通的ThreadPoolExecutor一样使用这个类,该类不同的地方在于,当每个Runnable执行完之后需要重置线程局部变量的状态。如果需要调试系统,你也可以获取到绑定的监听器。在我们的这个例子里,你可以看到,我们将监听器绑定到我们的增加线程局部变量的LOG上。注意,在Java 8中,java.util.logging.Logger的方法使用Supplier作为参数,这意味着我们不再需要任何代码来保证日志的性能。

import java.text.*;
import java.util.concurrent.*;
import java.util.logging.*;

public class ThreadPoolExecutorExtTest {
  private final static Logger LOG = Logger.getLogger(
      ThreadPoolExecutorExtTest.class.getName()
  );

  private static final ThreadLocal<DateFormat> df =
      new ThreadLocal<DateFormat>() {
        protected DateFormat initialValue() {
          return new SimpleDateFormat("yyyy-MM-dd");
        }
      };

  public static void main(String... args)
      throws InterruptedException {
    ThreadPoolExecutor tpe = new ThreadPoolExecutorExt(
        1, 1, 0, TimeUnit.SECONDS,
        new LinkedBlockingQueue<>(),
        (m, t, tl, v) -> {
          LOG.warning(
              () -> String.format(
                  "Thread %s %s ThreadLocal %s with value %s%n",
                  t, m, tl.getClass(), v)
          );
        }
    );

    for (int i = 0; i < 10; i++) {
      tpe.submit(() ->
          System.out.println(System.identityHashCode(df.get())));
      Thread.sleep(1000);
    }
    tpe.shutdown();
  }
}

我机器的输出结果如下:

914524658
May 23, 2015 9:28:50 PM ThreadPoolExecutorExtTest lambda$main$1
WARNING: Thread Thread[pool-1-thread-1,5,main] 
    ADDED ThreadLocal class ThreadPoolExecutorExtTest$1 
    with value java.text.SimpleDateFormat@f67a0200

957671209
May 23, 2015 9:28:51 PM ThreadPoolExecutorExtTest lambda$main$1
WARNING: Thread Thread[pool-1-thread-1,5,main] 
    ADDED ThreadLocal class ThreadPoolExecutorExtTest$1 
    with value java.text.SimpleDateFormat@f67a0200

466968587
May 23, 2015 9:28:52 PM ThreadPoolExecutorExtTest lambda$main$1
WARNING: Thread Thread[pool-1-thread-1,5,main] 
    ADDED ThreadLocal class ThreadPoolExecutorExtTest$1 
    with value java.text.SimpleDateFormat@f67a0200

现在,这段代码还没有在生产服务器上经过严格的考验,所以请谨慎使用。非常感谢你的阅读和支持。我真的非常感激。

原文地址:https://www.cnblogs.com/firstdream/p/7886674.html