ThreadPoolExecutor 优先级的线程池

本文使用 ThreadPoolExecutor 实现一个带优先级的线程池。通常的做法是直接使用优先级队列(java.util.PriorityQueue / java.util.concurrent.PriorityBlockingQueue),但这种方式没办法同步地获取结果,编程上有点复杂;而 java.util.concurrent.ThreadPoolExecutor 提供了 public <T> Future<T> submit(Callable<T> task); 方法,可以通过 Future.get() 阻塞当前线程、等待结果,从而实现同步调用。

public class PriorityThreadPoolExecutor extends ThreadPoolExecutor;

实现方法很简单:继承 ThreadPoolExecutor,并使用 PriorityBlockingQueue 作为任务队列。不过 PriorityBlockingQueue 有个坑,官方文档中是这样说的:

Operations on this class make no guarantees about the ordering of elements with equal priority.
*如果优先级相同,不能确定顺序。*

实际测试下来的结果是:如果优先级相同,则执行顺序跟插入顺序相反。这就尴尬了,这还是 FIFO 队列吗?官方文档给出了解决方式:给每一个队列元素分配一个递增的序号,优先级相同时按序号排序,照抄就可以了。限制是队列历史上的任务总数不能超过 Long.MAX_VALUE 个。为此需要实现一个 Comparable 的包装类。

class PriorityRunnable<E extends Comparable<? super E>> implements Runnable, Comparable<PriorityRunnable<E>>;

重载线程池的添加任务的方法,追加一个参数,如果使用基类的方法, 优先级为 0 。

public void execute(Runnable command, int priority);
public <T> Future<T> submit(Callable<T> task, int priority);
public <T> Future<T> submit(Runnable task, T result, int priority);
public Future<?> submit(Runnable task, int priority);

最终代码如下:

  1 package com.springboot.study.tests.threads;
  2 
  3 /**
  4  * @Author: guodong
  5  * @Date: 2021/3/22 15:20
  6  * @Version: 1.0
  7  * @Description:
  8  */
  9 import org.slf4j.Logger;
 10 import org.slf4j.LoggerFactory;
 11 import java.util.concurrent.*;
 12 import java.util.concurrent.atomic.AtomicLong;
 13 
 14 public class PriorityThreadPoolExecutor extends ThreadPoolExecutor {
 15 
 16     private static final Logger log = LoggerFactory.getLogger(PriorityThreadPoolExecutor.class);
 17 
 18     private ThreadLocal<Integer> local = new ThreadLocal<Integer>() {
 19         @Override
 20         protected Integer initialValue() {
 21             return 0;
 22         }
 23     };
 24 
 25     public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit) {
 26         super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue());
 27     }
 28 
 29     public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
 30         super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), threadFactory);
 31     }
 32 
 33     public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, RejectedExecutionHandler handler) {
 34         super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), handler);
 35     }
 36 
 37     public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory, RejectedExecutionHandler handler) {
 38         super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), threadFactory, handler);
 39     }
 40 
 41     protected static PriorityBlockingQueue getWorkQueue() {
 42         return new PriorityBlockingQueue();
 43     }
 44 
 45     @Override
 46     public void execute(Runnable command) {
 47         int priority = local.get();
 48         try {
 49             this.execute(command, priority);
 50         } finally {
 51             local.set(0);
 52         }
 53     }
 54 
 55     public void execute(Runnable command, int priority) {
 56         super.execute(new PriorityRunnable(command, priority));
 57     }
 58 
 59     public <T> Future<T> submit(Callable<T> task, int priority) {
 60         local.set(priority);
 61         return super.submit(task);
 62     }
 63 
 64     public <T> Future<T> submit(Runnable task, T result, int priority) {
 65         local.set(priority);
 66         return super.submit(task, result);
 67     }
 68 
 69     public Future<?> submit(Runnable task, int priority) {
 70         local.set(priority);
 71         return super.submit(task);
 72     }
 73 
 74     protected static class PriorityRunnable<E extends Comparable<? super E>> implements Runnable, Comparable<PriorityRunnable<E>> {
 75         private final static AtomicLong seq = new AtomicLong();
 76         private final long seqNum;
 77         Runnable run;
 78         private int priority;
 79 
 80         public PriorityRunnable(Runnable run, int priority) {
 81             seqNum = seq.getAndIncrement();
 82             this.run = run;
 83             this.priority = priority;
 84         }
 85 
 86         public int getPriority() {
 87             return priority;
 88         }
 89 
 90         public void setPriority(int priority) {
 91             this.priority = priority;
 92         }
 93 
 94         public Runnable getRun() {
 95             return run;
 96         }
 97 
 98         @Override
 99         public void run() {
100             this.run.run();
101         }
102 
103         @Override
104         public int compareTo(PriorityRunnable<E> other) {
105             int res = 0;
106             if (this.priority == other.priority) {
107                 if (other.run != this.run) {// ASC
108                     res = (seqNum < other.seqNum ? -1 : 1);
109                 }
110             } else {// DESC
111                 res = this.priority > other.priority ? -1 : 1;
112             }
113             return res;
114         }
115     }
116 }

下面是测试用例

package com.springboot.study.tests.threads;

/**
 * @Author: guodong
 * @Date: 2021/3/22 15:22
 * @Version: 1.0
 * @Description:
 */
import org.junit.Test;

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import static org.junit.Assert.*;

/**
 * Tests for {@code PriorityThreadPoolExecutor}: default (priority-less) submission, equal
 * priorities, and random priorities.
 *
 * <p>Every test uses a pool with a single core thread so that the execution order can be read
 * back from a shared {@link StringBuffer} that each task appends to.
 */
public class PriorityThreadPoolExecutorTest {

    /** Tasks submitted without a priority must run in submission order. */
    @Test
    public void testDefault() throws InterruptedException, ExecutionException {
        PriorityThreadPoolExecutor pool = new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES);
        try {
            Future<?>[] futures = new Future<?>[20];
            StringBuffer buffer = new StringBuffer();
            for (int i = 0; i < futures.length; i++) {
                int index = i;
                futures[i] = pool.submit(new Callable<Object>() {
                    @Override
                    public Object call() throws Exception {
                        Thread.sleep(10);
                        buffer.append(index + ", ");
                        return null;
                    }
                });
            }
            awaitAll(futures);
            System.out.println(buffer);
            assertEquals("0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, ", buffer.toString());
        } finally {
            // Stop the core thread; otherwise it stays alive after the test.
            pool.shutdownNow();
        }
    }

    /** Tasks submitted with the same explicit priority must run in submission order. */
    @Test
    public void testSamePriority() throws InterruptedException, ExecutionException {
        PriorityThreadPoolExecutor pool = new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES);
        try {
            Future<?>[] futures = new Future<?>[10];
            StringBuffer buffer = new StringBuffer();
            for (int i = 0; i < futures.length; i++) {
                futures[i] = pool.submit(new TenSecondTask<Void>(i, 1, buffer), 1);
            }
            awaitAll(futures);
            System.out.println(buffer);
            assertEquals("01@00, 01@01, 01@02, 01@03, 01@04, 01@05, 01@06, 01@07, 01@08, 01@09, ", buffer.toString());
        } finally {
            pool.shutdownNow();
        }
    }

    /** Tasks submitted with random priorities must run in non-increasing priority order. */
    @Test
    public void testRandomPriority() throws InterruptedException, ExecutionException {
        PriorityThreadPoolExecutor pool = new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES);
        try {
            Future<?>[] futures = new Future<?>[20];
            StringBuffer buffer = new StringBuffer();
            for (int i = 0; i < futures.length; i++) {
                int r = (int) (Math.random() * 100);
                futures[i] = pool.submit(new TenSecondTask<Void>(i, r, buffer), r);
            }
            awaitAll(futures);

            System.out.println(buffer);
            // String.split drops the trailing empty string, so there is one entry per task.
            String[] split = buffer.toString().split(", ");
            assertEquals(futures.length, split.length);
            // Start at 2: the earliest tasks may already have been picked up by the worker
            // before the remaining ones were queued, so their order is not asserted.
            // Only real entries are compared; no artificial trailing entry is appended,
            // because a task whose random priority is 0 would compare below it.
            for (int i = 2; i < split.length - 1; i++) {
                int current = Integer.parseInt(split[i].split("@")[0]);
                int next = Integer.parseInt(split[i + 1].split("@")[0]);
                assertTrue(current >= next);
            }
        } finally {
            pool.shutdownNow();
        }
    }

    /** Blocks until every future in {@code futures} has completed. */
    private static void awaitAll(Future<?>[] futures) throws InterruptedException, ExecutionException {
        for (Future<?> future : futures) {
            future.get();
        }
    }

    /**
     * Task that sleeps for 10 milliseconds and then appends {@code "<priority>@<index>, "}
     * (both zero-padded to two digits) to the shared buffer.
     *
     * @param <T> the nominal result type; {@link #call()} always returns {@code null}
     */
    public static class TenSecondTask<T> implements Callable<T> {
        private StringBuffer buffer;
        int index;
        int priority;

        /**
         * @param index    submission index recorded in the output
         * @param priority priority recorded in the output
         * @param buffer   shared buffer the task appends to when it runs
         */
        public TenSecondTask(int index, int priority, StringBuffer buffer) {
            this.index = index;
            this.priority = priority;
            this.buffer = buffer;
        }

        @Override
        public T call() throws Exception {
            Thread.sleep(10);
            buffer.append(String.format("%02d@%02d", this.priority, index)).append(", ");
            return null;
        }
    }
}
郭慕荣博客园
原文地址:https://www.cnblogs.com/jelly12345/p/14566393.html