基于SpringDataRedis实现一个分布式锁工具类

基础依赖

 <dependency>
       <groupId>org.springframework.boot</groupId>
       <artifactId>spring-boot-starter-data-redis</artifactId>
       <!-- NOTE(review): 2.x.x.RELEASE is a placeholder, not a real version; set a concrete 2.x version (or drop this element if a Spring Boot parent/BOM manages versions - confirm for your build). -->
       <version>2.x.x.RELEASE</version>
 </dependency>

核心类 DRedisLock

package com.idanchuang.component.redis.util.task;

import com.idanchuang.component.core.util.SpringUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;

/**
 * Redis-based distributed lock that is reentrant within a single thread.
 *
 * <p>The value stored under the lock key is {@code host:threadId:serial}. Host and thread id are
 * captured when the instance is constructed, so create the instance on the thread that will lock
 * and unlock it. If the key is already held by a value that starts with this instance's
 * {@code host:threadId:}, {@link #tryLock()} re-enters: it extends the key's expiry, and the
 * matching {@link #unlock()} leaves the key in place for the outer holder to delete.
 *
 * <p>Each instance is intended for a single lock/unlock cycle and is not thread-safe.
 *
 * <p>NOTE(review): this file declares package {@code com.idanchuang.component.redis.util.task},
 * while {@code DRedisLocks} and {@code RedisLockAspect} import
 * {@code com.idanchuang.component.redis.util.DRedisLock}; align one side.
 *
 * @author yjy
 * @date 2019/11/27 11:07
 **/
public class DRedisLock implements Lock {

    private static final Logger log = LoggerFactory.getLogger(DRedisLock.class);

    /** Default lock expiry, in milliseconds. */
    public final static long DEFAULT_TIMEOUT = 30000L;
    /** Prefix prepended to every lock key. */
    public final static String LOCK_PREFIX = "D_LOCK_";
    /** Default time to wait for the lock, in milliseconds. */
    public final static long DEFAULT_TRY_LOCK_TIMEOUT = 10000L;
    /** Default pause between acquisition attempts while waiting, in milliseconds. */
    public final static long DEFAULT_LOOP_INTERVAL = 10L;

    /** Serial counter that keeps each lock value unique within this JVM. */
    private static final AtomicLong SERIAL_NUM = new AtomicLong(0);
    /** Largest serial value; the counter wraps back to 1 after it. */
    private static final long MAX_SERIAL = 99999999L;
    /** Local host address, or a random UUID when it cannot be resolved. */
    private static final String CURRENT_HOST = resolveCurrentHost();

    /**
     * SET-if-absent with a millisecond expiry. ARGV[1] = value, ARGV[2] = expiry in ms.
     * Done in Lua because, when the application also uses Redisson, the template's
     * setIfAbsent returns null.
     */
    private static final DefaultRedisScript<Boolean> SET_IF_ABSENT_SCRIPT = booleanScript(
            "if redis.call('EXISTS', KEYS[1]) == 1 then return false end "
                    + "redis.call('SET', KEYS[1], ARGV[1], 'PX', ARGV[2]) "
                    + "return true");

    /**
     * Re-entry check. ARGV[1] = {@code host:threadId:} prefix, ARGV[2] = extra expiry in ms.
     * string.find is called with plain=true so '.' and '-' in the host are matched literally,
     * and the trailing ':' in the prefix stops thread 12 from matching a value of thread 123.
     */
    private static final DefaultRedisScript<Boolean> REENTER_SCRIPT = booleanScript(
            "local cur = redis.call('GET', KEYS[1]) "
                    + "if cur and string.find(cur, ARGV[1], 1, true) == 1 then "
                    + "local ttl = redis.call('PTTL', KEYS[1]) "
                    + "if ttl < 0 then ttl = 0 end "
                    + "redis.call('PEXPIRE', KEYS[1], ttl + tonumber(ARGV[2])) "
                    + "return true "
                    + "end "
                    + "return false");

    /** Compare-and-delete: removes KEYS[1] only while it still holds ARGV[1]. */
    private static final DefaultRedisScript<Boolean> UNLOCK_SCRIPT = booleanScript(
            "if redis.call('get', KEYS[1]) == ARGV[1] then redis.call('del', KEYS[1]) return true else return false end");

    /** StringRedisTemplate */
    private final StringRedisTemplate redisTemplate;
    /** Full Redis key (LOCK_PREFIX + lock name). */
    private final String lockKey;
    /** Lock expiry in milliseconds. */
    private final long timeout;
    /** Pause between acquisition attempts while waiting, in milliseconds. */
    private final long loopInterval;
    /** {@code host:threadId} of the constructing thread. */
    private final String hostThreadId;
    /** Value written under {@link #lockKey}: {@code host:threadId:serial}. */
    private final String lockValue;
    /** Whether the lock was obtained by re-entry. */
    private boolean reentrant = false;
    /** Whether this instance currently believes it holds the lock. */
    private boolean locked = false;

    public DRedisLock(String lockName) {
        this(lockName, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL);
    }

    public DRedisLock(String lockName, long timeout) {
        this(lockName, timeout, DEFAULT_LOOP_INTERVAL);
    }

    /**
     * @param lockName     lock name; the Redis key is {@code LOCK_PREFIX + lockName}
     * @param timeout      lock expiry in milliseconds
     * @param loopInterval pause between acquisition attempts in milliseconds
     * @throws IllegalArgumentException if {@code lockName} is null
     */
    public DRedisLock(String lockName, long timeout, long loopInterval) {
        if (lockName == null) {
            throw new IllegalArgumentException("lockName must assigned");
        }
        this.redisTemplate = SpringUtil.getBean(StringRedisTemplate.class);
        this.lockKey = LOCK_PREFIX + lockName;
        this.timeout = timeout;
        this.loopInterval = loopInterval;
        this.hostThreadId = CURRENT_HOST + ":" + Thread.currentThread().getId();
        this.lockValue = this.hostThreadId + ":" + getNextSerial();
    }

    /**
     * Acquires the lock, waiting at most {@link #DEFAULT_TRY_LOCK_TIMEOUT} ms.
     *
     * @throws RuntimeException if the lock is not acquired in time, or if the thread is
     *                          interrupted (the interrupt flag is restored first)
     */
    @Override
    public void lock() {
        try {
            if (!tryLock(DEFAULT_TRY_LOCK_TIMEOUT, TimeUnit.MILLISECONDS)) {
                throw new RuntimeException("try lock timeout, lockKey: " + this.lockKey);
            }
        } catch (InterruptedException e) {
            // lock() cannot declare InterruptedException, so keep the interrupt visible to the caller.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Acquires the lock, waiting at most {@link #DEFAULT_TRY_LOCK_TIMEOUT} ms. The wait can be
     * interrupted.
     *
     * @throws InterruptedException if the thread is interrupted while waiting
     * @throws RuntimeException     if the lock is not acquired in time
     */
    @Override
    public void lockInterruptibly() throws InterruptedException {
        if (!tryLock(DEFAULT_TRY_LOCK_TIMEOUT, TimeUnit.MILLISECONDS, true)) {
            throw new RuntimeException("try lock timeout, lockKey: " + this.lockKey);
        }
    }

    /**
     * Makes one immediate attempt to acquire (or re-enter) the lock.
     *
     * <p>Any Redis error is logged, {@link #unlock()} is called, and false is returned.
     *
     * @return true if the lock was acquired or re-entered, false otherwise
     */
    @Override
    public boolean tryLock() {
        try {
            Boolean acquired = setIfAbsent(this.lockKey, this.lockValue, expireMillis());
            if (Boolean.TRUE.equals(acquired)) {
                this.locked = true;
                log.debug("Lock success, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue);
                return true;
            }
            // Key already held: re-enter if the holder is this host + thread.
            Boolean reentered = this.redisTemplate.execute(REENTER_SCRIPT,
                    Collections.singletonList(this.lockKey),
                    this.hostThreadId + ":", String.valueOf(expireMillis()));
            if (Boolean.TRUE.equals(reentered)) {
                this.reentrant = true;
                this.locked = true;
                log.debug("Lock reentrant success, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue);
                return true;
            }
        } catch (Exception e) {
            log.error("tryLock error, do unlock, lockKey: {}, lockValue: {}", this.lockKey, lockValue, e);
            unlock();
        }
        return false;
    }

    /**
     * Sets {@code key} to {@code value} with an expiry, only if the key does not exist.
     *
     * @param key           key
     * @param value         value to store
     * @param timeoutMillis expiry in milliseconds (must be positive)
     * @return whether the value was set (null if the driver returned no result)
     */
    private Boolean setIfAbsent(String key, String value, long timeoutMillis) {
        return this.redisTemplate.execute(SET_IF_ABSENT_SCRIPT, Collections.singletonList(key),
                value, String.valueOf(timeoutMillis));
    }

    /**
     * Repeatedly attempts to acquire the lock until it succeeds or {@code time} elapses.
     *
     * @param time maximum time to wait
     * @param unit unit of {@code time}
     * @return whether the lock was acquired
     * @throws InterruptedException if interrupted while sleeping between attempts
     */
    @Override
    public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
        return tryLock(time, unit, false);
    }

    /**
     * Repeatedly attempts to acquire the lock until it succeeds or {@code time} elapses.
     * At least one attempt is always made.
     *
     * @param time          maximum time to wait
     * @param unit          unit of {@code time}
     * @param interruptibly when true, the interrupt flag is also checked before every attempt
     * @return whether the lock was acquired
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    private boolean tryLock(long time, TimeUnit unit, boolean interruptibly) throws InterruptedException {
        long waitMillis = unit.toMillis(time);
        long start = System.currentTimeMillis();
        while (true) {
            if (interruptibly && Thread.interrupted()) {
                throw new InterruptedException("tryLock interrupted");
            }
            if (tryLock()) {
                return true;
            }
            long remaining = waitMillis - (System.currentTimeMillis() - start);
            if (remaining <= 0) {
                return false;
            }
            // Never sleep past the deadline.
            Thread.sleep(Math.min(loopInterval, remaining));
        }
    }

    /**
     * Releases the lock if this instance holds it.
     *
     * <p>A re-entrant hold only clears the local flag; the outer holder deletes the key. A normal
     * hold deletes the key only while it still contains {@link #lockValue}. Failures are logged,
     * never thrown.
     */
    @Override
    public void unlock() {
        if (!this.locked) {
            return;
        }
        if (this.reentrant) {
            this.locked = false;
            log.debug("Unlock reentrant success, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue);
            return;
        }
        try {
            Boolean released = this.redisTemplate.execute(UNLOCK_SCRIPT,
                    Collections.singletonList(this.lockKey), this.lockValue);
            if (Boolean.TRUE.equals(released)) {
                this.locked = false;
                log.debug("Unlock success, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue);
                return;
            }
        } catch (Exception e) {
            log.error("Unlock error", e);
        }
        log.warn("Unlock failed, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue);
    }

    @Override
    public Condition newCondition() {
        throw new UnsupportedOperationException();
    }

    /** Expiry passed to Redis: the configured timeout, at least 1 ms (PX rejects 0). */
    private long expireMillis() {
        return Math.max(this.timeout, 1L);
    }

    /** Builds a Lua script with a Boolean result; instances are reused across calls. */
    private static DefaultRedisScript<Boolean> booleanScript(String scriptText) {
        DefaultRedisScript<Boolean> script = new DefaultRedisScript<>();
        script.setResultType(Boolean.class);
        script.setScriptText(scriptText);
        return script;
    }

    /** Local host address, or a random UUID when the host cannot be resolved. */
    private static String resolveCurrentHost() {
        try {
            return InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException e) {
            String fallback = UUID.randomUUID().toString();
            log.warn("DRedisLock > can not get local host, use uuid: {}", fallback);
            return fallback;
        }
    }

    /**
     * @return the next serial value, cycling through 1..MAX_SERIAL
     */
    private static long getNextSerial() {
        return SERIAL_NUM.updateAndGet(current -> current >= MAX_SERIAL ? 1L : current + 1L);
    }

    public static AtomicLong getSerialNum() {
        return SERIAL_NUM;
    }

    public static long getMaxSerial() {
        return MAX_SERIAL;
    }

    public static String getCurrentHost() {
        return CURRENT_HOST;
    }

    public String getLockKey() {
        return lockKey;
    }

    public long getTimeout() {
        return timeout;
    }

    public long getLoopInterval() {
        return loopInterval;
    }

    public String getHostThreadId() {
        return hostThreadId;
    }

    public String getLockValue() {
        return lockValue;
    }

    public boolean isReentrant() {
        return reentrant;
    }

    @Override
    public String toString() {
        return "DRedisLock{" +
                "lockKey='" + lockKey + '\'' +
                ", timeout=" + timeout +
                ", loopInterval=" + loopInterval +
                ", hostThreadId='" + hostThreadId + '\'' +
                ", lockValue='" + lockValue + '\'' +
                ", reentrant=" + reentrant +
                '}';
    }
}

工具类 SpringUtil

package com.idanchuang.component.core.util;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

/**
 * Static access to the Spring {@link ApplicationContext} for code that is not itself a bean
 * (for example, the {@code DRedisLock} constructor looks up its {@code StringRedisTemplate} here).
 *
 * <p>The context is captured when Spring calls {@link #setApplicationContext}. Every
 * {@code getBean} call made before that throws {@link IllegalStateException}.
 */
@Component
public class SpringUtil implements ApplicationContextAware {

    /** Set by Spring via setApplicationContext; volatile because it is read from arbitrary threads. */
    private static volatile ApplicationContext applicationContext;

    /**
     * @param tClass required bean type
     * @return the single bean of that type
     * @throws IllegalStateException if the context has not been set yet
     */
    public static <T> T getBean(Class<T> tClass) {
        return requireContext().getBean(tClass);
    }

    /**
     * @param beanName bean name
     * @return the bean, cast to the caller's expected type (unchecked)
     * @throws IllegalStateException if the context has not been set yet
     */
    @SuppressWarnings("unchecked")
    public static <T> T getBean(String beanName) {
        return (T) requireContext().getBean(beanName);
    }

    /**
     * @param beanName     bean name
     * @param requiredType required bean type
     * @return the bean
     * @throws IllegalStateException if the context has not been set yet
     */
    public static <T> T getBean(String beanName, Class<T> requiredType) {
        return requireContext().getBean(beanName, requiredType);
    }

    /** Reads the volatile field once, so the null check and the use see the same context. */
    private static ApplicationContext requireContext() {
        ApplicationContext context = SpringUtil.applicationContext;
        if (context == null) {
            throw new IllegalStateException("SpringUtil applicationContext is unready");
        }
        return context;
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        SpringUtil.applicationContext = applicationContext;
    }
}
View Code

至此, 我们已经可以开始使用分布式锁的功能啦, 如下

DRedisLock lock = new DRedisLock("testa");
// Standard Lock idiom: acquire before entering try, so the finally only runs once the lock is held.
lock.lock();
try {
    int b = a;
    a = b + 1;
    System.out.println(a);
} finally {
    lock.unlock();
}

我觉得这样使用起来也太麻烦了, 还要自己实例化lock对象来加锁和释放锁, 如果忘记释放的话问题就很大, 所以我又封装了一个 DRedisLocks 类

package com.idanchuang.component.redis.util;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.*;

import static com.idanchuang.component.redis.util.DRedisLock.*;

/**
 * Static helpers that run a code block while holding a {@link DRedisLock}, so callers never
 * pair lock/unlock calls by hand.
 *
 * <p>Every overload delegates to {@link #runWithLock(String, long, long, long, Callable)}.
 * Omitted arguments fall back to {@code DEFAULT_TRY_LOCK_TIMEOUT}, {@code DEFAULT_TIMEOUT} and
 * {@code DEFAULT_LOOP_INTERVAL} from {@link DRedisLock}.
 *
 * @author yjy
 * @date 2019/11/27 11:07
 **/
public class DRedisLocks {

    private static final Logger log = LoggerFactory.getLogger(DRedisLocks.class);

    /**
     * Runs {@code runnable} while holding the lock, using all default timeouts.
     * @param lockName lock name
     * @param runnable code to run
     */
    public static void runWithLock(String lockName, Runnable runnable) {
        runWithLock(lockName, DEFAULT_TRY_LOCK_TIMEOUT, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL, runnable);
    }

    /**
     * Runs {@code callable} while holding the lock, using all default timeouts.
     * @param lockName lock name
     * @param callable code to run
     * @param <V> result type
     * @return the callable's result
     */
    public static <V> V runWithLock(String lockName, Callable<V> callable) {
        return runWithLock(lockName, DEFAULT_TRY_LOCK_TIMEOUT, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL, callable);
    }

    /**
     * Runs {@code runnable} while holding the lock.
     * @param lockName lock name
     * @param tryTimeout max time to wait for the lock (ms)
     * @param runnable code to run
     */
    public static void runWithLock(String lockName, long tryTimeout, Runnable runnable) {
        runWithLock(lockName, tryTimeout, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL, runnable);
    }

    /**
     * Runs {@code callable} while holding the lock.
     * @param lockName lock name
     * @param tryTimeout max time to wait for the lock (ms)
     * @param callable code to run
     * @param <V> result type
     * @return the callable's result
     */
    public static <V> V runWithLock(String lockName, long tryTimeout, Callable<V> callable) {
        return runWithLock(lockName, tryTimeout, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL, callable);
    }

    /**
     * Runs {@code runnable} while holding the lock.
     * @param lockName lock name
     * @param tryTimeout max time to wait for the lock (ms)
     * @param lockTimeout lock expiry (ms)
     * @param runnable code to run
     */
    public static void runWithLock(String lockName, long tryTimeout, long lockTimeout, Runnable runnable) {
        runWithLock(lockName, tryTimeout, lockTimeout, DEFAULT_LOOP_INTERVAL, runnable);
    }

    /**
     * Runs {@code callable} while holding the lock.
     * @param lockName lock name
     * @param tryTimeout max time to wait for the lock (ms)
     * @param lockTimeout lock expiry (ms)
     * @param callable code to run
     * @param <V> result type
     * @return the callable's result
     */
    public static <V> V runWithLock(String lockName, long tryTimeout, long lockTimeout, Callable<V> callable) {
        return runWithLock(lockName, tryTimeout, lockTimeout, DEFAULT_LOOP_INTERVAL, callable);
    }

    /**
     * Runs {@code runnable} while holding the lock.
     * @param lockName lock name
     * @param tryTimeout max time to wait for the lock (ms)
     * @param lockTimeout lock expiry (ms)
     * @param loopInterval pause between acquisition attempts (ms)
     * @param runnable code to run
     */
    public static void runWithLock(String lockName, long tryTimeout, long lockTimeout, long loopInterval, Runnable runnable) {
        Callable<Void> callable = () -> {
            runnable.run();
            return null;
        };
        runWithLock(lockName, tryTimeout, lockTimeout, loopInterval, callable);
    }

    /**
     * Acquires the lock, runs {@code callable}, and releases the lock if it was acquired.
     * @param lockName lock name
     * @param tryTimeout max time to wait for the lock (ms)
     * @param lockTimeout lock expiry (ms)
     * @param loopInterval pause between acquisition attempts (ms)
     * @param callable code to run
     * @param <V> result type
     * @return the callable's result
     * @throws RuntimeException if the lock cannot be acquired in time, or wrapping any checked
     *         exception (including InterruptedException, after restoring the interrupt flag)
     */
    public static <V> V runWithLock(String lockName, long tryTimeout, long lockTimeout, long loopInterval, Callable<V> callable) {
        DRedisLock lock = new DRedisLock(lockName, lockTimeout, loopInterval);
        log.debug("Init DRedisLock > {}", lock);
        boolean acquired = false;
        try {
            acquired = lock.tryLock(tryTimeout, TimeUnit.MILLISECONDS);
            if (!acquired) {
                throw new RuntimeException("Get redisLock failed, lockName: " + lockName);
            }
            log.debug("Lock successful, lockName: {}", lockName);
            return callable.call();
        } catch (RuntimeException e) {
            throw e;
        } catch (InterruptedException e) {
            // Preserve the interrupt for the caller before converting to an unchecked exception.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            // Only release (and log) when we actually hold the lock.
            if (acquired) {
                lock.unlock();
                log.debug("Unlock successful, lockName: {}", lockName);
            }
        }
    }

}

现在我们就可以通过下面这种方式来使用分布式锁了, 而且不用自己手动加锁释放锁, 轻松了不少

DRedisLocks.runWithLock("testa", () -> {
    int b = a;
    a = b + 1;
    System.out.println(a);
});

如果要对整个方法加锁, 这样使用还是不够优雅. 能不能只用一个注解就实现分布式锁的能力? 答案当然是可以的, 为此我又新建了几个类

RedisLock 注解类

package com.idanchuang.component.redis.annotation;

import java.lang.annotation.*;

/**
 * Marks a method whose invocations must run while holding a Redis distributed lock.
 *
 * <p>NOTE: {@code @Inherited} only affects annotations placed on classes, so it has no effect on
 * this method-level annotation. It is kept only for compatibility.
 *
 * @author yjy
 * @date 2020/5/8 9:53
 **/
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Inherited
public @interface RedisLock {

    /** Lock name; when empty, {@code SimpleClassName:methodName} is used. */
    String value() default "";

    /** Maximum time to wait for the lock, in milliseconds. */
    long tryTimeout() default 10_000L;

    /** Lock expiry once acquired, in milliseconds. */
    long lockTimeout() default 30_000L;

    /** Pause between acquisition attempts while waiting, in milliseconds. */
    long loopInterval() default 10L;

    /** Business-key expressions; their resolved value is appended to the lock name. */
    String[] keys() default {};

    /** Error message used when the lock cannot be acquired; empty means a generated message. */
    String errMessage() default "";

}

RedisLockAspect AOP配置类

package com.idanchuang.component.redis.aspect;

import com.idanchuang.component.base.exception.common.ErrorCode;
import com.idanchuang.component.base.exception.core.ExFactory;
import com.idanchuang.component.redis.annotation.RedisLock;
import com.idanchuang.component.redis.helper.BusinessKeyHelper;
import com.idanchuang.component.redis.util.DRedisLock;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;

/**
 * Aspect for methods with {@link RedisLock} annotation.
 *
 * <p>Wraps each annotated method in a {@link DRedisLock}. The lock name is the annotation's
 * {@code value} (or {@code SimpleClassName:methodName} when empty) followed by the suffix
 * returned by {@code BusinessKeyHelper.getKeyName(pjp, keys)}.
 *
 * @author yjy
 */
@Aspect
@Component
public class RedisLockAspect {

    private static final Logger log = LoggerFactory.getLogger(RedisLockAspect.class);

    @Pointcut("@annotation(com.idanchuang.component.redis.annotation.RedisLock)")
    public void redisLockAnnotationPointcut() {
    }

    /**
     * Acquires the distributed lock, runs the intercepted method, and always releases the lock.
     *
     * @param pjp the intercepted invocation
     * @return whatever the intercepted method returns
     * @throws Throwable anything the intercepted method throws, or the exception from
     *                   {@code ExFactory.throwWith(ErrorCode.CONFLICT, ...)} when the lock is not
     *                   acquired within {@code tryTimeout}. That exception carries the annotation's
     *                   {@code errMessage}, or a generated message when it is empty.
     */
    @Around("redisLockAnnotationPointcut()")
    public Object invokeWithRedisLock(ProceedingJoinPoint pjp) throws Throwable {
        Method originMethod = resolveMethod(pjp);
        RedisLock annotation = originMethod.getAnnotation(RedisLock.class);
        if (annotation == null) {
            // Should not go through here.
            throw new IllegalStateException("Wrong state for RedisLock annotation");
        }
        String lockName = getName(annotation.value(), originMethod)
                + BusinessKeyHelper.getKeyName(pjp, annotation.keys());
        DRedisLock lock = new DRedisLock(lockName, annotation.lockTimeout(), annotation.loopInterval());
        try {
            // Wait up to tryTimeout ms for the lock.
            if (lock.tryLock(annotation.tryTimeout(), TimeUnit.MILLISECONDS)) {
                return pjp.proceed();
            }
            String msg = "Get redisLock failed, lockName: " + lockName;
            log.warn(msg);
            // Prefer the caller-supplied message; fall back to the generated one when none is set.
            String errMessage = StringUtils.isEmpty(annotation.errMessage()) ? msg : annotation.errMessage();
            throw ExFactory.throwWith(ErrorCode.CONFLICT, errMessage);
        } finally {
            // Always release; DRedisLock.unlock() returns immediately when the lock is not held.
            lock.unlock();
        }
    }

    /**
     * Resolves the base lock name.
     *
     * @param lockName     the annotation's value
     * @param originMethod the intercepted method
     * @return {@code lockName}, or {@code SimpleClassName:methodName} when it is empty
     */
    private String getName(String lockName, Method originMethod) {
        if (StringUtils.isEmpty(lockName)) {
            return originMethod.getDeclaringClass().getSimpleName() + ":" + originMethod.getName();
        }
        return lockName;
    }

    /**
     * Finds the intercepted method on the target class (or one of its superclasses).
     *
     * @throws IllegalStateException if no matching declared method is found
     */
    private Method resolveMethod(ProceedingJoinPoint joinPoint) {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Class<?> targetClass = joinPoint.getTarget().getClass();

        Method method = getDeclaredMethodFor(targetClass, signature.getName(),
                signature.getMethod().getParameterTypes());
        if (method == null) {
            throw new IllegalStateException("Cannot resolve target method: " + signature.getMethod().getName());
        }
        return method;
    }

    /**
     * Get declared method with provided name and parameterTypes in given class and its super classes.
     * All parameters should be valid.
     *
     * @param clazz          class where the method is located
     * @param name           method name
     * @param parameterTypes method parameter type list
     * @return resolved method, null if not found
     */
    private Method getDeclaredMethodFor(Class<?> clazz, String name, Class<?>... parameterTypes) {
        try {
            return clazz.getDeclaredMethod(name, parameterTypes);
        } catch (NoSuchMethodException e) {
            Class<?> superClass = clazz.getSuperclass();
            if (superClass != null) {
                return getDeclaredMethodFor(superClass, name, parameterTypes);
            }
        }
        return null;
    }

}

至此, 我们已经可以通过注解来实现接口的分布式锁能力

/**
 * Demo service whose {@link #doSomething()} runs under the {@link RedisLock} distributed lock.
 *
 * @author yjy
 * @date 2020/5/8 10:21
 **/
@Component
public class LockService {

    /** Shared counter mutated inside the locked section (demo only). */
    private static int a = 0;

    /**
     * Increments the counter, holds the lock for about one second, then prints the counter.
     * Lock "customLockName:88888222": expiry 60 s, wait for acquisition at most 20 s.
     */
    @RedisLock(value = "customLockName:88888222", lockTimeout = 60000L, tryTimeout = 20000L)
    public void doSomething() {
        a++;
        try {
            Thread.sleep(1000L);
        } catch (InterruptedException e) {
            // Restore the interrupt flag instead of swallowing it.
            Thread.currentThread().interrupt();
        }
        System.out.println("a: " + a);
    }

}

以上简单地介绍了我们实现的Redis分布式锁, 其实它的功能不止这些

它还支持线程内可重入、超时自动释放锁; 注解模式还支持通过 keys 属性解析方法参数来作为锁资源(由 BusinessKeyHelper 解析, 其实现本文未列出) 等等

好了, 今天就到这里吧, 拜拜~

原文地址:https://www.cnblogs.com/imyjy/p/15789410.html