Spring Boot 使用Mybatis拦截器结合Alibaba Druid解析SQL实现数据隔离

Spring Boot 使用Mybatis拦截器结合Alibaba Druid解析SQL实现数据隔离

有些时候我们需要使用到Mybatis拦截器对SQL进行后期处理,比如根据用户的角色给WHERE句子动态添加一些条件,比如如果用户是一个公司管理人员,我们希望对数据表的操作局限与该用户所对对应公司的数据。

比如一张简单的人员表:

t_person

{

​ long id; //主键

​ varchar name; //姓名

​ long com_id; //公司id

​ long dept_id; //部门id

}

当一个公司id为1的公司管理员查询公司人员数据时,我们希望sql语句应该是这样的:

SELECT * FROM t_person WHERE com_id=1

这样查询出来的数据都是本公司的数据。我们可以在每个查询的地方设置这样的条件,比如Mybatis的Example。

当然我们最好把重复的代码抽象出来,便于复用。

这里我们使用Mybaits 拦截器来做。

思路:

  1. 注入一个Mybaits拦截器
  2. 重写拦截方法,获取一个标志位是否需要进行数据隔离,是进入3步,否则跳到4
  3. 获取当前用户信息,如果是公司管理员,获取该公司管理员的公司id(假设为1),在拦截器获取拦截到SQL的语句加上WHERE com_id=1
  4. 返回处理后的SQL语句

伪代码:

@Intercepts({
    @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})
})
@Slf4j
@Component
public class SqlInterceptor implements Interceptor {

    /**
     * 拦截sql,并设置约束条件
     *
     * @param invocation
     * @return
     * @throws Throwable
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
		boolean flag=XXXX.XXX();
        //获取一个标志位是否需要进行数据隔离
       	if(flag){
            //获取当前用户信息
            User ust=XXX.getCurrentUser();
            //如果是公司管理员
            if(user.isCompanyAdmin()){
                //获取该公司管理员的公司id
                Long companyId=user.getCompanyId();
                //拦截器获取拦截到SQL的语句加上限制条件
                String newSql=oldSql+"where com_id="+companyId;
                //设置经过处理的SQL语句
                invocation.XXX(newSql);
            }
        }
		//返回处理后的SQL语句
        return invocation.proceed();
    }
 }

改进:

在拦截器中我们其实不应把业务逻辑写在里面,我们想要是Mybatis拦截器获取根据是否进行数据隔离条件,如果是则获取约束的条件,把约束条件加入到WHERE的条件中。

@Override
public Object intercept(Invocation invocation) throws Throwable {
	boolean flag=XXXX.XXX();
    //获取一个标志位是否需要进行数据隔离
   	if(flag){
        //获取约束条件
        Constraint constraint=XXXX.getCurrentConstraint();
        //根据当前sql语句和约束拼接sql语句
        String newSql=concatConstraint(oldSql,constraint);
        invocation.XXX(newSql);
    }
	//返回处理后的SQL语句
    return invocation.proceed();
}

问题:

对于Constraint我们Mybaitis拦截器想要一个Map, 该Map的key是字段名,value为字段值。拦截器处理的条件为AND连接,key和value通过‘=’连接。

比如Constraint可能的结构:

Constraint

{

​ Map<String,Object> map;

}

假设map内容为["com_id":1],有如下处理代码:

    //获取约束条件
    Constraint constraint=XXXX.getCurrentConstraint();
    //根据当前sql语句和约束拼接sql语句
    String newSql=concatConstraint(oldSql,constraint);

假设oldSql内容为SELECT * FROM t_person WHERE name like '张三%'

那么上面那段代码根据上面假设的约束条件constraint处理后的SQL语句(即newSql):

SELECT * FROM t_person WHERE name like '张三%' AND com_id=1 `

再比如oldSql内容为SELECT * FROM t_person。那么处理后的SQL语句应该为SELECT * FROM t_person WHERE com_id=1

显然我们需要解析SQL语句结构根据是否有WHERE关键动态修改的条件,解决什么时候要加WHERE com_id=1还是加AND com_id=1

引入Alibaba Druid:

​ 为了解决我们改进内容中问题,我们引入阿里巴巴Druid项目,来解析SQL结构,方便我们修改条件:

<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>druid</artifactId>
    <version>1.1.20</version>
</dependency>

先看如何约束条件如何表示:

@Data
public class ConstraintContext {
    /**
     * @author Shen Zhifeng
     * @version 1.0.0
     * @class MethodType
     * @classdesc 约束语句类型
     * @date 2020/9/8 17:44
     * @see
     * @since
     */
    public enum SqlType {
        //选择语句
        SELECT,
        //更新语句
        UPDATE,
        //删除语句
        DELETE
    }

    /**
     * 构建一个约束
     *
     * @param sqlType          约束语句类型
     * @param dbType           数据库类型
     * @param constraintsMap   约束键值
     * @param constraintString 约束sql语句
     * @return * @return: null
     * @throws java.lang.IllegalArgumentException dbType为空
     * @see
     * @since
     **/
    public ConstraintContext(ConstraintContext.SqlType sqlType, String dbType, Map<String, Object> constraintsMap, String constraintString) {
        Assert.notBlank(dbType, "dbType不能为空");
        this.sqlType = sqlType;
        this.dbType = dbType;
        this.constraintsMap = constraintsMap;
        this.constraintString = constraintString;
    }

    /**
     * 构建一个约束
     *
     * @param dbType         数据库类型
     * @param constraintsMap 约束键值
     * @return * @return: null
     * @throws java.lang.IllegalArgumentException dbType为空
     * @see
     * @since
     **/
    public ConstraintContext(String dbType, Map<String, Object> constraintsMap) {
        this(null, dbType, constraintsMap, null);
    }


    /**
     * 构建一个约束
     *
     * @param dbType           数据库类型
     * @param constraintString 约束sql语句
     * @return * @return: null
     * @throws java.lang.IllegalArgumentException dbType为空
     * @see
     * @since
     **/
    public ConstraintContext(String dbType, String constraintString) {
        this(null, dbType, null, constraintString);
    }

    /**
     * 要插入约束的语句类型
     */
    private final SqlType sqlType;
    /**
     * 数据库类型
     * oracle AliOracle mysql mariadb h2 postgresql edb sqlserver jtds db2 odps phoenix
     */
    private final String dbType;
    /**
     * 约束条件
     */
    private final Map<String, Object> constraintsMap;
    /**
     * 自定义的约束条件
     */
    private final String constraintString;
}

约束条件内容主要是有数据库内容,要拦截SQL语句类型,约束条件Map(改进中提到的),约束字符串(可以自定拦截内容(直接拼接到where),比约束条件Map相对起来比较灵活)。

我们看一下如何保存约束条件,通常约束条件保存在RequestAttributes。我们使用一个帮助类来完成设置和清除约束条件:

@Slf4j
public class ConstraintHelper {

    private static final String CONTEXT_KET = "bes_constrain";

    private ConstraintHelper() {
    }

    /**
     * 设置当前的约束条件上下文
     *
     * @param context
     * @return boss.xtrain.core.mybatis.interceptor.ConstraintContext
     * @throws
     * @see
     * @since
     **/
    public static ConstraintContext setContext(ConstraintContext context) {
        Assert.notNull(context, "context为null");
        RequestAttributes attributes = RequestContextHolder.currentRequestAttributes();
        attributes.setAttribute(CONTEXT_KET, context, RequestAttributes.SCOPE_REQUEST);
        return context;
    }

    /**
     * 获取当前的约束条件上下文
     *
     * @return boss.xtrain.core.mybatis.interceptor.ConstraintContext
     * @throws
     * @see
     * @since
     **/
    public static ConstraintContext getContext() {
        try {
            RequestAttributes attributes = RequestContextHolder.currentRequestAttributes();
            return (ConstraintContext) attributes.getAttribute(CONTEXT_KET, RequestAttributes.SCOPE_REQUEST);
        } catch (IllegalStateException e) {
            //不是走Controller方法进入,调用dao层代码,RequestContextHolder获取不到RequestAttributes会抛异常,这里捕获异常防止传播
            //因此对于需要数据隔离必须要走Controller方法进入
            log.error("数据约束失效:{}", e.getMessage());
        }
        return null;
    }

    /**
     * 清除当前的约束条件上下文
     *
     * @return void
     * @throws
     * @see
     * @since
     **/
    public static void clearContext() {
        RequestAttributes attributes = RequestContextHolder.currentRequestAttributes();
        attributes.removeAttribute(CONTEXT_KET, RequestAttributes.SCOPE_REQUEST);
    }
}

最后我们看下完整拦截器代码:

@Intercepts({
    @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})
})
@Slf4j
@Component
public class SqlInterceptor implements Interceptor {

    /**
     * 拦截sql,并设置约束条件
     *
     * @param invocation
     * @return
     * @throws Throwable
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        ConstraintContext constraints = ConstraintHelper.getContext();
        //没有设置约束上下文的sql不进行拦截
        if (constraints != null) {
            String constraintSql = getConstraintString(constraints.getConstraintsMap(), constraints.getConstraintString());
            if (constraintSql != null) {
                StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
                MetaObject metaStatementHandler = SystemMetaObject.forObject(statementHandler);
                //获取sql
                String oldSql = String.valueOf(metaStatementHandler.getValue("delegate.boundSql.sql"));
                if (log.isDebugEnabled()) {
                    log.info("拦截器处理前的sql语句为" + oldSql);
                }
                String newSql = contactConditions(oldSql, constraintSql, constraints.getDbType(), constraints.getSqlType());
                //重新设置sql
                metaStatementHandler.setValue("delegate.boundSql.sql", newSql);
                if (log.isDebugEnabled()) {
                    log.info("经拦截器处理后的sql语句为" + newSql);
                }
            }
        }

        return invocation.proceed();
    }

    /**
     * 根据map或constraintString生成约束sql字符串,优先使用map(如果设置的话)
     * map的约束条件,默认and连接,=约束
     *
     * @param map              约束map
     * @param constraintString 约束sql
     * @return
     */
    private String getConstraintString(Map<String, Object> map, String constraintString) {
        if (map != null && !map.isEmpty()) {
            StringBuilder constraintsBuffer = new StringBuilder();
            Set<String> keys = map.keySet();
            Iterator<String> keyIter = keys.iterator();
            if (keyIter.hasNext()) {
                String key = keyIter.next();
                constraintsBuffer.append(key).append(" = " + getSqlByClass(map.get(key)));
            }
            while (keyIter.hasNext()) {
                String key = keyIter.next();
                constraintsBuffer.append(" AND ").append(key).append(" = " + getSqlByClass(map.get(key)));
            }
            return constraintsBuffer.toString();
        }
        if (!StrUtil.isBlank(constraintString)) {
            return constraintString;
        }
        return null;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        //nothing to do
    }

    /**
     * 根据语句类型和sqlType决定是否添加约束条件
     *
     * @param oldSql        旧sql
     * @param constraintSql 约束sql
     * @param dbType        数据库类型
     * @param sqlType       语句类型
     * @return
     */
    private static String contactConditions(String oldSql, String constraintSql, String dbType, ConstraintContext.SqlType sqlType) {
        SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(oldSql, dbType);
        List<SQLStatement> stmtList = parser.parseStatementList();
        SQLStatement stmt = stmtList.get(0);
        SQLExprParser constraintsParser = SQLParserUtils.createExprParser(constraintSql, dbType);
        SQLExpr constraintsExpr = constraintsParser.expr();
        //选择语句
        boolean useSelection = (sqlType == null || sqlType == ConstraintContext.SqlType.SELECT) && stmt instanceof SQLSelectStatement;
        if (useSelection) {
            SQLSelectStatement selectStmt = (SQLSelectStatement) stmt;
            // 拿到SQLSelect
            SQLSelect sqlselect = selectStmt.getSelect();
            SQLSelectQueryBlock query = (SQLSelectQueryBlock) sqlselect.getQuery();
            SQLExpr whereExpr = query.getWhere();
            // 修改where表达式
            if (whereExpr == null) {
                query.setWhere(constraintsExpr);
            } else {
                SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd, constraintsExpr);
                query.setWhere(newWhereExpr);
            }
            sqlselect.setQuery(query);
            return sqlselect.toString();
        }
        //更新语句
        boolean useUpgrade = (sqlType == null || sqlType == ConstraintContext.SqlType.UPDATE) && stmt instanceof SQLUpdateStatement;
        if (useUpgrade) {
            SQLUpdateStatement updateStatement = (SQLUpdateStatement) stmt;
            // 拿到SQLSelect
            SQLExpr whereExpr = updateStatement.getWhere();
            // 修改where表达式
            if (whereExpr == null) {
                updateStatement.setWhere(constraintsExpr);
            } else {
                SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd, constraintsExpr);
                updateStatement.setWhere(newWhereExpr);
            }
            return updateStatement.toString();
        }
        //删除语句
        boolean useDeleting = (sqlType == null || sqlType == ConstraintContext.SqlType.DELETE) && stmt instanceof SQLDeleteStatement;
        if (useDeleting) {
            SQLDeleteStatement deleteStatement = (SQLDeleteStatement) stmt;
            // 拿到SQLSelect
            SQLExpr whereExpr = deleteStatement.getWhere();
            // 修改where表达式
            if (whereExpr == null) {
                deleteStatement.setWhere(constraintsExpr);
            } else {
                SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd, constraintsExpr);
                deleteStatement.setWhere(newWhereExpr);
            }
            return deleteStatement.toString();
        }
        return oldSql;
    }


    /**
     * 转换java对象到sql支持的类型值
     *
     * @param value
     * @return
     */
    private static String getSqlByClass(Object value) {
        if (value instanceof Number) {
            return value + "";
        } else if (value instanceof String) {
            return "'" + value + "'";
        }

        return "'" + value.toString() + "'";
    }
}

我们为SQL语句设置where子句条件时,如果SQL语句本身没有WHERE关键字,我们就设置一个否则把原来WHERE条件拼接上我们的约束条件。解析SQL语句并结构化减少我们代码工作量。

    if (useSelection) {
        SQLSelectStatement selectStmt = (SQLSelectStatement) stmt;
        // 拿到SQLSelect
        SQLSelect sqlselect = selectStmt.getSelect();
        SQLSelectQueryBlock query = (SQLSelectQueryBlock) sqlselect.getQuery();
        SQLExpr whereExpr = query.getWhere();
        // 修改where表达式
        if (whereExpr == null) {
            query.setWhere(constraintsExpr);
        } else {
            SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd, constraintsExpr);
            query.setWhere(newWhereExpr);
        }
        sqlselect.setQuery(query);
        return sqlselect.toString();
    }

使用:

比如我们有一个UserDao查询用户方法:

    public List<User> query(){
        //约束条件构建
        Map<String,Object> map=new HashMap<>();
        //获取当前用户
        SecurityUser user = SecurityContextHolder.getCurrentUser();
        //如果时超级管理员
        if(user.isSystemAdmin()){
            //没什么可以做的
        }
        //如果是公司管理员
        if(user.isCompanyAdmin()){
            //获取用户的公司id
            Long companyId = user.getCompanyId();
            //设置约束条件 相当于 com_id=${companyId}
            map.put("com_id",companyId);
        }else{
            //其他人员就无法查询了,因为0=1条件无法满足
            map.put("0",1);
        }
        ConstraintContext context = new ConstraintContext("mysql",map);
        //设置条件
        ConstraintHelper.setContext(context);

        List<User> users = mapper.selectAll();
        //清除条件,防止约束条件污染其他语句
        ConstraintHelper.clearContext();
        return users;
    }

以下代码其实可以进行复用,可以做一个切面类配和注解解决。就不在贴代码了。

public List<User> query(){
    //约束条件构建
    Map<String,Object> map=new HashMap<>();
    //获取当前用户
    SecurityUser user = SecurityContextHolder.getCurrentUser();
    //如果时超级管理员
    if(user.isSystemAdmin()){
        //没什么可以做的
    }
    //如果是公司管理员
    if(user.isCompanyAdmin()){
        //获取用户的公司id
        Long companyId = user.getCompanyId();
        //设置约束条件 相当于 com_id=${companyId}
        map.put("com_id",companyId);
    }else{
        //其他人员就无法查询了,因为0=1条件无法满足
        map.put("0",1);
    }
    ConstraintContext context = new ConstraintContext("mysql",map);
    //设置条件
    ConstraintHelper.setContext(context);
	//...
}
原文地址:https://www.cnblogs.com/Zhifeng-Shen/p/13792852.html