基于 @SelectProvider 注解实现无侵入的通用Dao

基于 @SelectProvider 注解实现无侵入的通用Dao

项目框架

基于 SpringBoot 2.x 和 mybatis-spring-boot-starter

代码设计

通用Dao

/**
 * Generic CRUD mapper contract. A concrete DAO only has to extend this interface with its
 * entity type and id type; every statement below gets its SQL from the method of the same
 * name on {@link BaseSqlProvider}.
 *
 * @param <E> entity type
 * @param <I> primary-key type
 */
public interface BaseDao<E, I> {

    /** Loads the single row matching the given id. */
    @SelectProvider(type = BaseSqlProvider.class, method = "getById")
    E getById(I id);

    /** Lists the rows matching every non-null property of the example entity. */
    @SelectProvider(type = BaseSqlProvider.class, method = "listByEntity")
    List<E> listByEntity(E e);

    /** Single-result variant of {@link #listByEntity(Object)}. */
    @SelectProvider(type = BaseSqlProvider.class, method = "getByEntity")
    E getByEntity(E e);

    /**
     * Lists the rows whose column, identified by a getter method reference, equals {@code val}.
     * The provider reads the arguments by the names {@code lambda} and {@code val}.
     */
    @SelectProvider(type = BaseSqlProvider.class, method = "listByLambdaQuery")
    List<E> listByLambdaQuery(GetterFunction<E, ?> lambda, Object val);

    /**
     * Same lookup as {@link #listByLambdaQuery(GetterFunction, Object)}, but the generated SQL
     * carries {@code LIMIT 1}; the declared return type is nevertheless a list.
     */
    @SelectProvider(type = BaseSqlProvider.class, method = "getByLambdaQuery")
    List<E> getByLambdaQuery(GetterFunction<E, ?> lambda, Object val);

    /** Lists the rows whose id is contained in the given collection. */
    @SelectProvider(type = BaseSqlProvider.class, method = "listByIds")
    List<E> listByIds(Collection<I> collection);

    /** Inserts one entity; the generated key is written back to its {@code id} property. */
    @InsertProvider(type = BaseSqlProvider.class, method = "insert")
    @Options(keyProperty = "id", useGeneratedKeys = true)
    int insert(E e);

    /** Inserts all entities with a single multi-row statement. */
    @InsertProvider(type = BaseSqlProvider.class, method = "insertBatch")
    @Options(keyProperty = "id", useGeneratedKeys = true)
    int insertBatch(Collection<E> list);

    /** Updates the non-null properties of the entity, addressed by its id. */
    @UpdateProvider(type = BaseSqlProvider.class, method = "update")
    int update(E e);

    /** Updates all given entities, each addressed by its id, in one statement. */
    @UpdateProvider(type = BaseSqlProvider.class, method = "updateBatch")
    int updateBatch(Collection<E> list);

    /** Deletes the row matching the given id. */
    @DeleteProvider(type = BaseSqlProvider.class, method = "deleteById")
    int deleteById(I id);

    /** Deletes the rows matching every non-null property of the example entity. */
    @DeleteProvider(type = BaseSqlProvider.class, method = "deleteByEntity")
    int deleteByEntity(E e);

    /** Deletes the rows whose id is contained in the given collection. */
    @DeleteProvider(type = BaseSqlProvider.class, method = "deleteByIds")
    int deleteByIds(Collection<I> list);

    /** Counts all rows of the entity's table. */
    @SelectProvider(type = BaseSqlProvider.class, method = "countAll")
    int countAll();

    /** Counts the rows matching every non-null property of the example entity. */
    @SelectProvider(type = BaseSqlProvider.class, method = "countByEntity")
    int countByEntity(E e);

}

通用SQL Provider

//用于缓存和返回通用SQL语句
public class BaseSqlProvider {

    private static final Map<Integer,String> sqlCache = new ConcurrentHashMap<>();

    public String getById(ProviderContext context) {
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.getById(context);
            sqlCache.put(key,value);
        }
        return value;
    }

    public String getByEntity(Object object,ProviderContext context) throws Exception {
        if (object==null){
            throw new Exception("entity can not be null!");
        }
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.getByEntity(object);
            sqlCache.put(key,value);
        }
        return value;
    }

    public String listByIds(Collection collection, ProviderContext context) throws Exception {
        if (collection==null || collection.size()==0){
            throw new Exception("id list can not be empty!");
        }
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.listByIds(context);
            sqlCache.put(key,value);
        }
        return value;
    }

    public String listByEntity(Object object,ProviderContext context) throws Exception {
        if (object==null){
            throw new Exception("entity can not be null!");
        }
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.listByEntity(object);
            sqlCache.put(key,value);
        }
        return value;
    }

    public String listByLambdaQuery(Map<String,Object> params,ProviderContext context) throws Exception {
        Object val = params.get("val");
        if (val==null){
            throw new Exception("value can not be null!");
        }
        GetterFunction lambda = (GetterFunction)params.get("lambda");
        int key = context.hashCode();
        String fieldName = lambda.getFieldName(lambda);
        String value = sqlCache.get(key+fieldName);
        if (value==null){
            value = BaseSqlBuilder.listByField(fieldName,context);
            sqlCache.put(key+fieldName,value);
        }
        return value;
    }

    public String getByLambdaQuery(Map<String,Object> params,ProviderContext context) throws Exception {
        Object val = params.get("val");
        if (val==null){
            throw new Exception("value can not be null!");
        }
        GetterFunction lambda = (GetterFunction)params.get("lambda");
        int key = context.hashCode();
        String fieldName = lambda.getFieldName(lambda);
        String value = sqlCache.get(key+fieldName);
        if (value==null){
            value = BaseSqlBuilder.getByField(fieldName,context);
            sqlCache.put(key+fieldName,value);
        }
        return value;
    }

    public String insert(Object object, ProviderContext context) throws Exception {
        if (object==null){
            throw new Exception("entity can not be null!");
        }
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.insert(object);
            sqlCache.put(key,value);
        }
        return value;
    }

    public String insertBatch(Collection collection, ProviderContext context) throws Exception {
        if (collection==null || collection.size()==0){
            throw new Exception("entity list can not be empty!");
        }
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.insertBatch(context);
            sqlCache.put(key,value);
        }
        return value;
    }

    public String update(Object object, ProviderContext context) throws Exception {
        if (object==null){
            throw new Exception("entity can not be null!");
        }
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.update(object);
            sqlCache.put(key,value);
        }
        return value;
    }

    public String updateBatch(Collection collection, ProviderContext context) throws Exception {
        if (collection==null || collection.size()==0){
            throw new Exception("entity list can not be empty!");
        }
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.updateBatch(context);
            sqlCache.put(key,value);
        }
        return value;
    }

    public String deleteById(ProviderContext context) {
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.deleteById(context);
            sqlCache.put(key,value);
        }
        return value;
    }

    public String deleteByEntity(Object object,ProviderContext context) throws Exception {
        if (object==null){
            throw new Exception("entity can not be null!");
        }
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.deleteByEntity(object);
            sqlCache.put(key,value);
        }
        return value;
    }

    public String deleteByIds(Collection collection, ProviderContext context) throws Exception {
        if (collection==null || collection.size()==0){
            throw new Exception("id list can not be empty!");
        }
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.deleteByIds(context);
            sqlCache.put(key,value);
        }
        return value;
    }

    public String countAll(ProviderContext context) {
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.countAll(context);
            sqlCache.put(key,value);
        }
        return value;
    }

    public String countByEntity(Object object,ProviderContext context) throws Exception {
        if (object==null){
            throw new Exception("entity can not be null!");
        }
        int key = context.hashCode();
        String value = sqlCache.get(key);
        if (value==null){
            value = BaseSqlBuilder.countByEntity(object);
            sqlCache.put(key,value);
        }
        return value;
    }

}

通用SQL构建类

//生成通用SQL语句
public class BaseSqlBuilder {

    public static String getById(ProviderContext context) {
        Class eClass = TableEntityMetaData.getEntityType(context);
        String tableName = TableEntityMetaData.tableName(eClass);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        return "SELECT "+String.join(",",columns)+" FROM "+tableName+" WHERE "+TableEntityMetaData.getIdColumn(eClass)+" = #{id}";
    }

    public static String listByEntity(Object object) {
        Class eClass = object.getClass();
        String tableName = TableEntityMetaData.tableName(eClass);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder sql = new StringBuilder("<script> SELECT ");
        sql.append(String.join(",",columns));
        sql.append(" FROM ").append(tableName);
        sql.append(" <where>");
        whereByEntity(fields,columns,sql);
        sql.append("</where></script>");
        return sql.toString();
    }

    public static String getByEntity(Object object) {
        return listByEntity(object)+" LIMIT 1";
    }

    public static String listByIds(ProviderContext context) {
        Class eClass = TableEntityMetaData.getEntityType(context);
        String tableName = TableEntityMetaData.tableName(eClass);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder sql = new StringBuilder("<script> SELECT ");
        sql.append(String.join(",",columns));
        sql.append(" FROM ").append(tableName);
        sql.append(" WHERE ").append(TableEntityMetaData.getIdColumn(eClass)).append(" IN ");
        sql.append("<foreach item="item" collection="list" separator="," open="(" close=")" index="index">");
        sql.append("#{item}</foreach></script>");
        return sql.toString();
    }

    public static String listByField(String fieldName, ProviderContext context) throws Exception {
        Class eClass = TableEntityMetaData.getEntityType(context);
        String tableName = TableEntityMetaData.tableName(eClass);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        if (!fields.contains(fieldName)) {
            throw new Exception("not exist column '"+fieldName+"'");
        }
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        return "SELECT "+String.join(",",columns)+" FROM "+tableName+" WHERE "+TableEntityMetaData.toLowerCase(fieldName)+" = #{val}";
    }

    public static String getByField(String fieldName, ProviderContext context) throws Exception {
        return listByField(fieldName,context)+" LIMIT 1";
    }

    public static String insert(Object object) {
        Class eClass = object.getClass();
        String tableName = TableEntityMetaData.tableName(eClass);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder sql = new StringBuilder();
        sql.append("<script> INSERT INTO ").append(tableName);
        sql.append(" <trim prefix="(" suffix=")" suffixOverrides=",">");
        for (int i = 0; i < fields.size(); i++) {
            sql.append("<if test="").append(fields.get(i)).append(" != null">");
            sql.append(columns.get(i)).append(",").append("</if>");
        }
        sql.append("</trim><trim prefix="values (" suffix=")" suffixOverrides=",">");
        for (int i = 0; i < fields.size(); i++) {
            sql.append("<if test="").append(fields.get(i)).append(" != null">");
            sql.append("#{").append(fields.get(i)).append("},").append("</if>");
        }
        sql.append("</trim></script>");
        return sql.toString();
    }

    public static String insertBatch(ProviderContext context) {
        Class eClass = TableEntityMetaData.getEntityType(context);
        String tableName = TableEntityMetaData.tableName(eClass);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder sql = new StringBuilder();
        sql.append("<script> INSERT INTO ").append(tableName);
        sql.append("(").append(String.join(", ",columns)).append(") values ");
        sql.append("<foreach item="item" collection="list" separator="," open="" close="" index="index"> (");
        for (int i = 0; i < fields.size(); i++) {
            sql.append("#{item.").append(fields.get(i)).append("}");
            if (i<fields.size()-1){
                sql.append(", ");
            }
        }
        sql.append(")</foreach></script>");
        return sql.toString();
    }

    public static String update(Object object) {
        Class eClass = object.getClass();
        String tableName = TableEntityMetaData.tableName(eClass);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder sql = new StringBuilder("<script> UPDATE ");
        sql.append(tableName).append(" <set>");
        for (int i = 1; i < fields.size(); i++) {
            sql.append("<if test="").append(fields.get(i)).append(" != null">");
            sql.append(columns.get(i)).append(" = #{").append(fields.get(i)).append("},</if>");
        }
        sql.append("</set> WHERE ").append(TableEntityMetaData.getIdColumn(eClass));
        sql.append(" = #{").append(TableEntityMetaData.getIdField(eClass)).append("} </script>");
        return sql.toString();
    }

    public static String updateBatch(ProviderContext context) {
        Class eClass = TableEntityMetaData.getEntityType(context);
        String tableName = TableEntityMetaData.tableName(eClass);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder sql = new StringBuilder("<script> UPDATE ");
        sql.append(tableName).append(" <trim prefix="set" suffixOverrides=",">");
        for (int i = 1; i < fields.size(); i++) {
            sql.append("<trim prefix="").append(columns.get(i)).append(" = case" suffix="end,">");
            sql.append("<foreach collection="list" item="item" index="index">");
            sql.append("when ").append(TableEntityMetaData.getIdColumn(eClass));
            sql.append(" = #{item.").append(TableEntityMetaData.getIdField(eClass)).append("} then #{item.").append(fields.get(i)).append("}");
            sql.append("</foreach></trim>");
        }
        sql.append("</trim> WHERE ").append(TableEntityMetaData.getIdColumn(eClass)).append(" IN ");
        sql.append("<foreach collection="list" index="index" item="item" separator="," open="(" close=")">");
        sql.append("#{item.").append(TableEntityMetaData.getIdField(eClass)).append("} </foreach></script>");
        return sql.toString();
    }

    public static String deleteById(ProviderContext context) {
        Class eClass = TableEntityMetaData.getEntityType(context);
        String tableName = TableEntityMetaData.tableName(eClass);
        return "DELETE FROM "+tableName+" WHERE "+TableEntityMetaData.getIdColumn(eClass)+" = #{id}";
    }

    public static String deleteByEntity(Object object) {
        Class eClass = object.getClass();
        String tableName = TableEntityMetaData.tableName(eClass);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder sql = new StringBuilder("<script> DELETE FROM ");
        sql.append(tableName).append(" <where> ");
        whereByEntity(fields,columns,sql);
        sql.append("</where></script>");
        return sql.toString();
    }

    public static String deleteByIds(ProviderContext context) {
        Class eClass = TableEntityMetaData.getEntityType(context);
        String tableName = TableEntityMetaData.tableName(eClass);
        StringBuilder sql = new StringBuilder("<script> DELETE FROM ");
        sql.append(tableName).append(" WHERE ").append(TableEntityMetaData.getIdColumn(eClass)).append(" IN ");
        sql.append("<foreach item="item" collection="list" separator="," open="(" close=")" index="index">");
        sql.append("#{item}</foreach></script>");
        return sql.toString();
    }

    public static String countAll(ProviderContext context) {
        Class eClass = TableEntityMetaData.getEntityType(context);
        String tableName = TableEntityMetaData.tableName(eClass);
        return "SELECT COUNT(*) FROM "+tableName;
    }

    public static String countByEntity(Object object) {
        Class eClass = object.getClass();
        String tableName = TableEntityMetaData.tableName(eClass);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder sql = new StringBuilder("<script> SELECT COUNT(*) FROM ");
        sql.append(tableName).append(" <where> ");
        whereByEntity(fields,columns,sql);
        sql.append("</where></script>");
        return sql.toString();
    }

    private static void whereByEntity(List<String> fields,List<String> columns,StringBuilder sql){
        for (int i = 0; i < fields.size(); i++) {
            sql.append("<if test="").append(fields.get(i)).append(" != null">");
            sql.append("and ").append(columns.get(i));
            sql.append(" = #{").append(fields.get(i)).append("}</if>");
        }
    }
}

表实体元数据工具类

/**
 * Derives table and entity metadata from a {@code ProviderContext} or an entity class.
 *
 * <p>Table and column names are computed from the entity type by naming convention only:
 * database identifiers must be all lower case with words separated by underscores, and entity
 * properties must be camel case.
 */
public class TableEntityMetaData {

    /**
     * Resolves the entity class as the first type argument of a parameterized interface the
     * mapper extends (e.g. {@code User} for {@code UserDao extends BaseDao<User,Integer>}).
     *
     * @throws IllegalStateException if no directly extended interface declares a concrete
     *         class as its first type argument
     */
    public static Class getEntityType(ProviderContext context) {
        Class<?> mapperType = context.getMapperType();
        // Scan every super-interface rather than blindly casting the first one, so a mapper
        // that also extends a non-generic interface does not fail with a ClassCastException.
        for (java.lang.reflect.Type candidate : mapperType.getGenericInterfaces()) {
            if (candidate instanceof ParameterizedType) {
                java.lang.reflect.Type[] typeArgs = ((ParameterizedType) candidate).getActualTypeArguments();
                if (typeArgs.length > 0 && typeArgs[0] instanceof Class) {
                    return (Class) typeArgs[0];
                }
            }
        }
        throw new IllegalStateException(
                "cannot resolve entity type of mapper " + mapperType.getName());
    }

    /** Name of the primary-key column; currently fixed to {@code id}. */
    public static String getIdColumn(Class eClass){
        return  "id";
    }

    /** Name of the primary-key property; currently fixed to {@code id}. */
    public static String getIdField(Class eClass){
        return "id";
    }

    /** Table name: the entity's simple class name converted to lower snake case. */
    public static String tableName(Class eClass) {
        String entityName = eClass.getSimpleName();
        return toLowerCase(entityName);
    }

    /**
     * Names of the entity's own instance fields, with the id property moved to index 0.
     * Static and compiler-generated fields (for example {@code serialVersionUID}) are skipped
     * because they are not columns.
     */
    public static List<String> entityFields(Class eClass) {
        String idField = getIdField(eClass);
        Field[] declared = eClass.getDeclaredFields();
        List<String> entityFields = new ArrayList<>(declared.length);
        for (Field field : declared) {
            if (field.isSynthetic() || java.lang.reflect.Modifier.isStatic(field.getModifiers())) {
                continue;
            }
            String name = field.getName();
            if (name.equals(idField)) {
                // Callers rely on the id being first (update statements skip index 0).
                entityFields.add(0, name);
            } else {
                entityFields.add(name);
            }
        }
        return entityFields;
    }

    /** Converts each property name to its column name, preserving order. */
    public static List<String> tableColumns(List<String> entityFields) {
        List<String> tableColumns = new ArrayList<>(entityFields.size());
        for (String field : entityFields) {
            tableColumns.add(toLowerCase(field));
        }
        return tableColumns;
    }

    /**
     * Converts a camel-case name to lower snake case, e.g. {@code gmtCreate} to
     * {@code gmt_create} and {@code User} to {@code user}.
     */
    public static String toLowerCase(String camelStr) {
        // Locale.ROOT keeps the result independent of the JVM default locale.
        String lowerCase = camelStr.replaceAll("[A-Z]", "_$0").toLowerCase(java.util.Locale.ROOT);
        if (lowerCase.startsWith("_")) {
            // A leading capital (class names) produces a leading underscore; drop it.
            lowerCase = lowerCase.substring(1);
        }
        return lowerCase;
    }
}

lambda query function 接口

/**
 * Serializable getter reference (e.g. {@code User::getUsername}) from which the name of the
 * referenced property can be recovered at runtime.
 *
 * @param <T> type declaring the getter
 * @param <R> getter return type
 */
@FunctionalInterface
public interface GetterFunction<T, R> extends Serializable, Function<T, R> {

    /**
     * Resolves the property name behind a getter method reference.
     *
     * <p>The implementation method name is read from the lambda's serialized form, a leading
     * {@code get} or {@code is} is removed, and the remainder is decapitalized. Only the
     * leading prefix is stripped, so names that contain the prefix again
     * ({@code getTargetException}, {@code isRegistered}) resolve correctly.
     *
     * @param func the getter method reference to inspect
     * @return the decapitalized property name
     * @throws RuntimeException if the serialized form of {@code func} cannot be obtained
     */
    default String getFieldName(GetterFunction<T, ?> func) {
        try {
            // Serializable lambdas carry a synthetic writeReplace() that returns a SerializedLambda.
            Method writeReplace = func.getClass().getDeclaredMethod("writeReplace");
            writeReplace.setAccessible(true);
            SerializedLambda serialized = (SerializedLambda) writeReplace.invoke(func);
            String getter = serialized.getImplMethodName();

            String property = getter;
            for (String prefix : new String[] {"get", "is"}) {
                // Require a non-lowercase character after the prefix so that ordinary words
                // such as "issue" or "getaway" are not mistaken for prefixed getters.
                if (getter.startsWith(prefix)
                        && getter.length() > prefix.length()
                        && !Character.isLowerCase(getter.charAt(prefix.length()))) {
                    property = getter.substring(prefix.length());
                    break;
                }
            }
            return Introspector.decapitalize(property);
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }
}

实体类

// Example entity. TableEntityMetaData derives the table name from the class name and the
// column names from the field names, so field order and naming matter here.
@Data// Lombok
public class User {
    /**
    * Primary key, auto-increment.
    */
    private Integer id;
    private String username;
    private String password;
    /**
    * Record creation time; defaults to the current time.
    */
    private Date gmtCreate;
    /**
    * Record modification time; defaults to the current time.
    */
    private Date gmtModified;
}

具体Dao

// Concrete DAO for User with an Integer id: it declares no methods of its own and takes all
// generic CRUD methods from BaseDao.
public interface UserDao extends BaseDao<User,Integer> {
}

yml mybatis配置

一定要开启下划线转驼峰设置,否则查询结果中的下划线列名(如 gmt_create)无法映射到实体的驼峰属性(如 gmtCreate)

# MyBatis settings (presumably in the Spring Boot application.yml — confirm for your project).
mybatis:
  # Where hand-written mapper XML files are looked up.
  mapper-locations: classpath*:mapper/**/*.xml
  configuration:
    # Must be enabled: the generic SQL selects snake_case columns (e.g. gmt_create)
    # that have to be mapped onto camelCase entity properties (e.g. gmtCreate).
    map-underscore-to-camel-case: true

使用

  • 自己写的SQL可以放在resources/mapper路径里的Mapper.xml中
  • 对应dao方法则放在具体的dao中
原文地址:https://www.cnblogs.com/xiaogblog/p/14151888.html