Shiro & JWT

环境介绍

JDK 1.8 + Maven + Spring Boot 2.0.0.RELEASE

依赖

<!-- shiro -->
<!-- shiro-spring: Spring integration for Shiro; provides ShiroFilterFactoryBean used by ShiroConfig -->
		<dependency>
			<groupId>org.apache.shiro</groupId>
			<artifactId>shiro-spring</artifactId>
			<version>1.4.1</version>
		</dependency>

		<!-- Jwt -->
		<!-- jjwt: presumably backs JwtUtil; Claims / ExpiredJwtException used by JwtMatcher look like jjwt io.jsonwebtoken types. TODO confirm -->
		<dependency>
			<groupId>io.jsonwebtoken</groupId>
			<artifactId>jjwt</artifactId>
			<version>0.9.1</version>
		</dependency>
		<!-- NOTE(review): none of the code shown in this article uses com.auth0 java-jwt; if JwtUtil only relies on jjwt, this second JWT library is redundant. Confirm before removing -->
		<dependency>
			<groupId>com.auth0</groupId>
			<artifactId>java-jwt</artifactId>
			<version>3.11.0</version>
		</dependency>

配置类

  • 设置Filter、realm
  • 关闭session
  • 配置拦截规则
/**
 * Shiro wiring for stateless JWT authentication: filter factory, security manager and realms.
 */
@Configuration
public class ShiroConfig {

    /**
     * Creates the Shiro filter factory backed by {@code securityManager}.
     *
     * @param securityManager security manager used by every Shiro filter
     * @param shiroFilterChainManager source of the named filters and of the URL rules
     * @return the configured {@link ShiroFilterFactoryBean}
     */
    @Bean
    public ShiroFilterFactoryBean shiroFilterFactoryBean(SecurityManager securityManager, ShiroFilterChainManager shiroFilterChainManager) {
        ShiroFilterFactoryBean factory = new ShiroFilterFactoryBean();
        factory.setSecurityManager(securityManager);
        // Named filter instances first, then the pattern -> filter-name rules that reference them.
        factory.setFilters(shiroFilterChainManager.initFilterMap());
        factory.setFilterChainDefinitionMap(shiroFilterChainManager.initFilterChainDefinitionMap());
        return factory;
    }

    /**
     * Creates the web security manager. It uses the token-filtering authenticator, the
     * session-less subject factory and the application realms, and it is also registered
     * with {@link SecurityUtils}.
     *
     * @param shiroRealmManager source of the realms
     * @return the configured security manager
     */
    @Bean
    public DefaultWebSecurityManager securityManager(ShiroRealmManager shiroRealmManager) {
        DefaultWebSecurityManager manager = new DefaultWebSecurityManager();

        // Install the authenticator before setRealms(): setRealms() hands the realms to the
        // authenticator present at that moment, so one installed afterwards would have none.
        manager.setAuthenticator(new AModularRealmAuthenticator());

        // Give the manager's own session-storage evaluator to the subject factory, which
        // switches session creation and session storage off.
        DefaultSubjectDAO dao = (DefaultSubjectDAO) manager.getSubjectDAO();
        manager.setSubjectFactory(new JwtDefaultSubjectFactory((DefaultSessionStorageEvaluator) dao.getSessionStorageEvaluator()));

        manager.setRealms(shiroRealmManager.getRealms());
        SecurityUtils.setSecurityManager(manager);
        return manager;
    }
}

关闭Session

  • 关闭shiro session 且 不存储session
  • 该类在配置类中实例化,并通过securityManager.setSubjectFactory设置到SecurityManager中
/**
 * Subject factory for stateless (JWT) requests.
 *
 * <p>Subjects never create a Shiro session, and subject state is never persisted into a
 * session. The factory is instantiated by the Shiro configuration and installed via
 * {@code securityManager.setSubjectFactory(...)}.
 */
public class JwtDefaultSubjectFactory extends DefaultWebSubjectFactory {
    private final DefaultSessionStorageEvaluator sessionStorageEvaluator;

    /**
     * Creates the factory and disables session storage on the given evaluator.
     *
     * @param sessionStorageEvaluator the security manager's session-storage evaluator
     * @throws NullPointerException if {@code sessionStorageEvaluator} is null
     */
    public JwtDefaultSubjectFactory(DefaultSessionStorageEvaluator sessionStorageEvaluator) {
        this.sessionStorageEvaluator = java.util.Objects.requireNonNull(sessionStorageEvaluator, "sessionStorageEvaluator");
        // Configure the shared evaluator exactly once. Rewriting it from every request thread
        // is redundant and an unsynchronized write to shared state.
        this.sessionStorageEvaluator.setSessionStorageEnabled(false);
    }

    /**
     * Creates the subject for the current request with session creation disabled.
     *
     * @param context the subject context built by Shiro for this request
     * @return the created subject
     */
    @Override
    public Subject createSubject(SubjectContext context) {
        // Never create a session for this subject.
        context.setSessionCreationEnabled(false);
        return super.createSubject(context);
    }
}

配置过滤器链

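  • 该类在配置类中注入:initFilterMap的结果传给shiroFilterFactoryBean.setFilters,initFilterChainDefinitionMap的结果传给shiroFilterFactoryBean.setFilterChainDefinitionMap。注意:Shiro按规则的添加顺序匹配URL,第一个匹配的规则生效,因此/account/logout这类具体路径必须放在/account/**之前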
/**
 * Supplies the named Shiro filters and the URL-to-filter rules used by the Shiro configuration.
 *
 * <p>{@link #initFilterMap()} feeds {@code ShiroFilterFactoryBean.setFilters} and
 * {@link #initFilterChainDefinitionMap()} feeds {@code setFilterChainDefinitionMap}.
 */
@Component
public class ShiroFilterChainManager {
    private final StringRedisTemplate stringRedisTemplate;

    @Autowired
    public ShiroFilterChainManager(StringRedisTemplate stringRedisTemplate) {
        this.stringRedisTemplate = stringRedisTemplate;
    }

    /**
     * Creates the named filter instances referenced by the chain definitions.
     *
     * @return filter name -> filter instance, in insertion order
     */
    public Map<String, Filter> initFilterMap() {
        Map<String, Filter> filterMap = new LinkedHashMap<>();
        PasswordFilter passwordFilter = new PasswordFilter();
        passwordFilter.setRedisTemplate(stringRedisTemplate);
        JwtFilter jwtFilter = new JwtFilter();
        jwtFilter.setRedisTemplate(stringRedisTemplate);

        filterMap.put("anon", new AnonymousFilter());
        filterMap.put("logout", new LogoutFilter());
        filterMap.put("auth", passwordFilter);
        filterMap.put("jwt", jwtFilter);
        return filterMap;
    }

    /**
     * Builds the URL pattern -> filter-name rules.
     *
     * <p>Shiro checks the patterns in insertion order and applies the first one that matches.
     * Specific paths must therefore be registered before broader wildcards that also cover them.
     *
     * @return ordered pattern -> filter chain map
     */
    public Map<String, String> initFilterChainDefinitionMap() {
        Map<String, String> filterRuleMap = new LinkedHashMap<>();
        // Must precede "/account/**"; otherwise logout requests match the "auth" rule first and
        // the logout filter never runs.
        filterRuleMap.put("/account/logout", "logout");
        filterRuleMap.put("/account/**", "auth");
        filterRuleMap.put("/device/**", "jwt");
        return filterRuleMap;
    }
}

JwtFilter

/**
 * jwt 过滤器
 */
public class JwtFilter extends AccessControlFilter {
    private StringRedisTemplate redisTemplate;

    public void setRedisTemplate(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * 判断是否登录
     * 在登录的情况下会走此方法
     * @param servletRequest
     * @param servletResponse
     * @param mappedValue
     * @return 此方法返回true直接访问控制器
     * @throws Exception
     */
    @Override
    protected boolean isAccessAllowed(ServletRequest servletRequest, ServletResponse servletResponse, Object mappedValue) throws Exception {
        //发起请求的时候就需要在Header中放一个Authorization,值就是对应的Token
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        Subject subject = getSubject(servletRequest, servletResponse);
        String appId = request.getHeader("appId");

        if ((null != subject && subject.isAuthenticated())) {
            AuthenticationToken jwtToken = createJwtToken(request);
            try {
                // 委托 realm 进行登录认证
                //所以这个地方最终还是调用JwtRealm进行的认证
                getSubject(servletRequest, servletResponse).login(jwtToken);
                //也就是subject.login(token)
                return true;
            } catch (AuthenticationException e) {
                // 判断是否是jwt token过期
                if (e.getMessage().equals("JWT-EXPIRED")) {
                    String refreshJwt = redisTemplate.opsForValue().get("JWT-SESSION-" + appId);
                    if (!StringUtils.isEmpty(refreshJwt)) {
                        // redis中找到key,说明刷新时间未过期,更新jwt
                        Map<String, Object> chaim = new HashMap<>();
                        JwtUtil jwtUtil = new JwtUtil();

                        chaim.put("appId", appId);
                        long refreshTokenTime = 30000L; // 30秒
                        String newJwtToken = jwtUtil.encode(appId, refreshTokenTime >> 1, chaim); // 15秒
                        redisTemplate.opsForValue().set("JWT-SESSION-" + appId, newJwtToken, refreshTokenTime, TimeUnit.MILLISECONDS);
                        ResponseEntity responseEntity = new ResponseEntity().ok(1005, "refresh jwt token").addData("jwt", newJwtToken);
                        RequestResponseUtil.responseWrite(JSON.toJSONString(responseEntity), servletResponse);
                        return false;
                    } else {
                        // jwt 过期,客户端重新登陆
                        ResponseEntity responseEntity = new ResponseEntity().error(1006, "expired jwt");
                        RequestResponseUtil.responseWrite(JSON.toJSONString(responseEntity), servletResponse);
                        return false;
                    }
                }
                // 其他jwt错误
                ResponseEntity responseEntity = new ResponseEntity().error(1007, "error jwt");
                RequestResponseUtil.responseWrite(JSON.toJSONString(responseEntity), servletResponse);
                return false;
            } catch (Exception e) {
                // 其他错误
                e.printStackTrace();
                ResponseEntity responseEntity = new ResponseEntity().error(1111, "request error");
                RequestResponseUtil.responseWrite(JSON.toJSONString(responseEntity), servletResponse);
                return false;
            }
        } else {
            // 错误请求
            ResponseEntity responseEntity = new ResponseEntity().error(1111, "request error");
            RequestResponseUtil.responseWrite(JSON.toJSONString(responseEntity), servletResponse);
            return false;
        }

    }

    /**
     * 是否是拒绝登录
     * 没有登录的情况下会走此方法
     * @param servletRequest
     * @param servletResponse
     * @return 此方法返回true直接访问控制器
     * @throws Exception
     */
    @Override
    protected boolean onAccessDenied(ServletRequest servletRequest, ServletResponse servletResponse) throws Exception {
        return false;
    }

    private AuthenticationToken createJwtToken(HttpServletRequest request) {
        Map<String, String> params = RequestResponseUtil.getRequestParameters(request);
        String authorization = params.get("Authorization");
        return new JwtToken(authorization);
    }
}

配置Realm

  • 该类在配置类中注入,getRealms的返回结果传给securityManager.setRealms
/**
 * Builds the realms installed on the security manager.
 *
 * <p>The Shiro configuration passes {@link #getRealms()} to {@code securityManager.setRealms}.
 */
@Component
public class ShiroRealmManager {
    private final UserProvider userProvider;
    private final JwtMatcher jwtMatcher;

    @Autowired
    public ShiroRealmManager(UserProvider userProvider, JwtMatcher jwtMatcher) {
        this.userProvider = userProvider;
        this.jwtMatcher = jwtMatcher;
    }

    /**
     * Creates fresh realm instances on every call: a password realm and a JWT realm.
     *
     * @return the realms, password realm first
     */
    public Collection<Realm> getRealms() {
        // Username/password login.
        PasswordRealm passwordRealm = new PasswordRealm();
        passwordRealm.setUserProvider(userProvider);

        // JWT login; JwtMatcher performs the actual token validation.
        JwtRealm jwtRealm = new JwtRealm();
        jwtRealm.setCredentialsMatcher(jwtMatcher);

        Collection<Realm> realms = new java.util.ArrayList<>(2);
        realms.add(passwordRealm);
        realms.add(jwtRealm);
        return realms;
    }
}

JWT Realm

  1. 继承AuthorizingRealm,重写认证授权方法
  2. 重写supports方法,标识这个Realm是专门用来验证JwtToken
/**
 * Realm that authenticates {@link JwtToken}s only.
 *
 * <p>The JWT string serves as both principal and credentials. The configured credentials
 * matcher performs the actual validation.
 */
public class JwtRealm extends AuthorizingRealm {

    /**
     * Restricts this realm to {@link JwtToken}. Other token types (e.g. UsernamePasswordToken)
     * are left to the other realms.
     *
     * @param token the token submitted to {@code subject.login}
     * @return {@code true} only for {@link JwtToken}
     */
    @Override
    public boolean supports(AuthenticationToken token) {
        return token instanceof JwtToken;
    }

    /**
     * Provides no authorization data; this realm only authenticates.
     *
     * @param principals the authenticated principals
     * @return always {@code null}
     */
    @Override
    protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) {
        return null;
    }

    /**
     * Wraps the JWT string carried by {@code token} into authentication info whose principal and
     * credentials are both that string, tagged with this realm's name.
     *
     * @param token the {@link JwtToken} passed in from the filter
     * @return authentication info for the credentials matcher to verify
     * @throws AuthenticationException if the token carries no JWT or a blank one
     */
    @Override
    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) throws AuthenticationException {
        String jwt = (String) token.getPrincipal();
        if (jwt == null || jwt.trim().isEmpty()) {
            // Shiro's contract for this method is AuthenticationException, not a raw NPE.
            throw new AuthenticationException("jwtToken 不允许为空");
        }
        return new SimpleAuthenticationInfo(jwt, jwt, this.getName());
    }
}

多Realm认证的实现

  • 多Realm认证默认由ModularRealmAuthenticator实现:它根据已配置Realm的数量(而不是支持当前Token的Realm数量)选择单Realm认证或多Realm认证
    // Default implementation in ModularRealmAuthenticator (quoted from Shiro for reference). It
    // implements the abstract method declared by AbstractAuthenticator: with exactly one configured
    // realm it runs single-realm authentication, otherwise multi-realm authentication.
    // Note: the count covers ALL configured realms, not only those that support this token.
    protected AuthenticationInfo doAuthenticate(AuthenticationToken authenticationToken) throws AuthenticationException {
        // verifies that realms have been configured
        this.assertRealmsConfigured();
        Collection<Realm> realms = this.getRealms();
        return realms.size() == 1 ? this.doSingleRealmAuthentication((Realm)realms.iterator().next(), authenticationToken) : this.doMultiRealmAuthentication(realms, authenticationToken);
    }

继承ModularRealmAuthenticator,过滤Jwt的Realm认证

  • 继承ModularRealmAuthenticator,重写doAuthenticate
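  • 获取Realm集合,只保留supports(token)返回true的Realm(这里调用了JwtRealm重写的supports方法,因此非JwtToken不会交给JwtRealm)
  • 对筛选后的Realm数量进行判断:只剩一个Realm时走单Realm认证,该Realm抛出的异常(如"JWT-EXPIRED")会原样抛给调用方;而多Realm认证在默认的AtLeastOneSuccessfulStrategy下会把各Realm的异常汇总成一个通用异常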
/**
 *  对多个Ream进行认证的实现
 */
public class AModularRealmAuthenticator extends ModularRealmAuthenticator {

    /**
     *
     * 重写多Realm认证方法
     * 获取Realm集合,排除掉JwtRealm的验证(JwtRealm专门用来验证JwtToken,不负责验证其他的token,JwtRealm重写了supports方法)
     *
     * @param authenticationToken
     * @return
     * @throws AuthenticationException
     */
    @Override
    protected AuthenticationInfo doAuthenticate(AuthenticationToken authenticationToken) throws AuthenticationException {

        assertRealmsConfigured(); // 判断realm是否为空
        List<Realm> realms = this.getRealms()
                .stream()
                .filter(realm -> {
                    return realm.supports(authenticationToken);
                })
                .collect(toList());

        // 判断realm的个数,如果realm为一个,则进行单realm认证;如果realm为多个,则进行多realm认证,实现了AbstractAuthenticator的方法
        return realms.size() == 1 ? this.doSingleRealmAuthentication(realms.iterator().next(), authenticationToken) : this.doMultiRealmAuthentication(realms, authenticationToken);
    }
}

凭证匹配类

  • 实现CredentialsMatcher接口,重写doCredentialsMatch
  • 该类设置在JwtRealm中,doCredentialsMatch方法在Realm认证时调用;它通过异常信息("JWT-EXPIRED"/"JWT-ERROR")区分错误类型,供Filter判断

/**
 *  凭证匹配类,在创建Realm时设置
 */
@Component
public class JwtMatcher implements CredentialsMatcher {
    @Override
    public boolean doCredentialsMatch(AuthenticationToken authenticationToken, AuthenticationInfo authenticationInfo) {
        String token = (String)authenticationInfo.getCredentials();
        Claims decode;

        // 写入AuthenticationException异常信息,在Filter中判断错误类型
        try {
            decode = new JwtUtil().decode(token);
        } catch (ExpiredJwtException e) {
            throw new AuthenticationException("JWT-EXPIRED");
        } catch (Exception e) {
            throw new AuthenticationException("JWT-ERROR");
        }
        if (null == decode) {
            throw new AuthenticationException("JWT-ERROR");
        }

        return true;
    }
}

参考
https://www.jianshu.com/p/9b6eb3308294

原文地址:https://www.cnblogs.com/xiongyungang/p/14056858.html