TransmittableThreadLocal使用踩坑(主线程set,异步线程get)

背景:为了获取相关字段方便,项目里使用了TransmittableThreadLocal上下文,在异步逻辑中get值时发现并非当前请求的值,且是偶发状况(并发问题)。

发现:TransmittableThreadLocal是阿里开源的可以实现父子线程值传递的工具,其子线程必须使用TtlRunnable\TtlCallable修饰或者线程池使用TtlExecutors修饰(防止数据“污染”),如果没有使用装饰后的线程池,那么使用TransmittableThreadLocal上下文,就有可能出现线程不安全的问题。

话不多说上代码:

封装的上下文,成员变量RequestHeader

package org.example.ttl.threadLocal;

import com.alibaba.ttl.TransmittableThreadLocal;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import org.apache.commons.lang3.ObjectUtils;

/**
 * description:
 * author: JohnsonLiu
 * create at:  2021/12/24  23:19
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class RequestContext {

    private static final ThreadLocal<RequestContext> transmittableThreadLocal = new TransmittableThreadLocal();


    private static final RequestContext INSTANCE = new RequestContext();
    private RequestHeader requestHeader;


    public static void create(RequestHeader requestHeader) {
        transmittableThreadLocal.set(new RequestContext(requestHeader));
    }

    public static RequestContext current() {
        return ObjectUtils.defaultIfNull(transmittableThreadLocal.get(), INSTANCE);
    }

    public static RequestHeader get() {
        return current().getRequestHeader();
    }

    public static void remove() {
        transmittableThreadLocal.set(null);
    }


    @Data
    @AllArgsConstructor
    @NoArgsConstructor
    @ToString
    static class RequestHeader {
        private String requestUrl;
        private String requestType;
    }
}

获取上下文内容的case:


package org.example.ttl.threadLocal;

import com.alibaba.ttl.threadpool.TtlExecutors;

import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
* description: TransmittableThreadLocal正确使用
* author: JohnsonLiu
* create at: 2021/12/24 22:24
* 验证结论:
* 1.线程池必须使用TtlExecutors修饰,或者Runnable\Callable必须使用TtlRunnable\TtlCallable修饰
* ---->原因:子线程复用,子线程拥有的上下文内容会对下次使用造成“污染”,而修饰后的子线程在执行run方法后会进行“回放”,防止污染
*/
public class TransmittableThreadLocalCase2 {

// 为达到线程100%复用便于测试,线程池核心数1

private static final Executor TTL_EXECUTOR = TtlExecutors.getTtlExecutor(new ThreadPoolExecutor(1, 1, 1000, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<>(1000)));

// 如果使用一般的线程池或者Runnable\Callable时,会存在线程“污染”,比如线程池中线程会复用,复用的线程会“污染”该线程执行下一次任务
private static final Executor EXECUTOR = new ThreadPoolExecutor(1, 1, 1000, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<>(1000));


public static void main(String[] args) {

RequestContext.create(new RequestContext.RequestHeader("url", "get"));
System.out.println(Thread.currentThread().getName() + " 子线程(rm之前 同步):" + RequestContext.get());
// 模拟另一个线程修改上下文内容
EXECUTOR.execute(() -> {
RequestContext.create(new RequestContext.RequestHeader("url", "put"));
});

// 保证上面子线程修改成功
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
e.printStackTrace();
}

// 异步获取上下文内容
TTL_EXECUTOR.execute(() -> {
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName() + " 子线程(rm之前 异步):" + RequestContext.get());
});

// 主线程修改上下文内容
RequestContext.create(new RequestContext.RequestHeader("url", "post"));
System.out.println(Thread.currentThread().getName() + " 子线程(rm之前 同步<reCreate>):" + RequestContext.get());

// 主线程remove
RequestContext.remove();

// 子线程获取remove后的上下文内容
TTL_EXECUTOR.execute(() -> {
try {
Thread.sleep(3000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName() + " 子线程(rm之后 异步):" + RequestContext.get());
});
}
}
 
使用一般线程池结果:
使用修饰后的线程池结果:

这种问题的解决办法:

如果大家跟我一样存在这样的使用,那么也会低概率存在这样的问题,正确的使用方式是:

子线程必须使用TtlRunnable\TtlCallable修饰或者线程池使用TtlExecutors修饰,这一点很容易被遗漏,比如上下文和异步逻辑不是同一个人开发的,那么异步逻辑的开发者就很可能直接在异步逻辑中使用上下文,而忽略装饰线程池,造成线程复用时的“数据污染”。

另外还有一种不同于上面的上下文用法,同样使用不当也会存在线程安全问题:

上代码样例

package org.example.ttl.threadLocal;

import com.alibaba.ttl.TransmittableThreadLocal;

import java.util.LinkedHashMap;
import java.util.Map;

/**
 * description: TransmittableThreadLocal正确使用
 * author: JohnsonLiu
 * create at:  2021/12/24  23:19
 */
public class ServiceContext {

    private static final ThreadLocal<Map<Integer, Integer>> transmittableThreadLocal = new TransmittableThreadLocal() {
        /**
         * 如果使用的是TtlExecutors装饰的线程池或者TtlRunnable、TtlCallable装饰的任务
         * 重写copy方法且重新赋值给新的LinkedHashMap,不然会导致父子线程都是持有同一个引用,只要有修改取值都会变化。引用值线程不安全
         * parentValue是父线程执行子任务那个时刻的快照值,后续父线程再次set值也不会影响子线程get,因为已经不是同一个引用
         * @param parentValue
         * @return
         */
        @Override
        public Object copy(Object parentValue) {
            if (parentValue instanceof Map) {
                System.out.println("copy");
                return new LinkedHashMap<Integer, Integer>((Map) parentValue);
            }
            return null;
        }

        /**
         * 如果使用普通线程池执行异步任务,重写childValue即可实现子线程获取的是父线程执行任务那个时刻的快照值,重新赋值给新的LinkedHashMap,父线程修改不会影响子线程(非共享)
         * 但是如果使用的是TtlExecutors装饰的线程池或者TtlRunnable、TtlCallable装饰的任务,此时就会变成引用共享,必须得重写copy方法才能实现非共享
         * @param parentValue
         * @return
         */
        @Override
        protected Object childValue(Object parentValue) {
            if (parentValue instanceof Map) {
                System.out.println("childValue");
                return new LinkedHashMap<Integer, Integer>((Map) parentValue);
            }
            return null;
        }

        /**
         * 初始化,每次get时都会进行初始化
         * @return
         */
        @Override
        protected Object initialValue() {
            System.out.println("initialValue");
            return new LinkedHashMap<Integer, Integer>();
        }
    };

    public static void set(Integer key, Integer value) {
        transmittableThreadLocal.get().put(key, value);
    }

    public static Map<Integer, Integer> get() {
        return transmittableThreadLocal.get();
    }

    public static void remove() {
        transmittableThreadLocal.remove();
    }
}

使用case:

package org.example.ttl.threadLocal;

import com.alibaba.ttl.threadpool.TtlExecutors;

import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * description: TransmittableThreadLocal正确使用
 * author: JohnsonLiu
 * create at:  2021/12/24  22:24
 */
public class TransmittableThreadLocalCase {


    private static final Executor executor = TtlExecutors.getTtlExecutor(new ThreadPoolExecutor(1, 1, 1000, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<>(1000)));
//    private static final Executor executor = new ThreadPoolExecutor(1, 1, 1000, TimeUnit.MICROSECONDS, new LinkedBlockingQueue<>(1000));

    static int i = 0;

    public static void main(String[] args) {

        ServiceContext.set(++i, i);

        executor.execute(() -> {
            try {
                Thread.sleep(3000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(Thread.currentThread().getName() + " 子线程(rm之前):" + ServiceContext.get());
        });

        ServiceContext.set(++i, i);
        ServiceContext.remove();

        executor.execute(() -> {
            try {
                Thread.sleep(3000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(Thread.currentThread().getName() + " 子线程(rm之后):" + ServiceContext.get());
        });
    }
}

代码中已有详细的说明,大家可自行copy代码验证。

以上仅代表个人理解,如有错误之处,还请留言指教,进行更正!


原文地址:https://www.cnblogs.com/JohnsonLiu/p/15733680.html