【手撸一个ORM】第八步、查询工具类

一、实体查询

using MyOrm.Commons;
using MyOrm.DbParameters;
using MyOrm.Expressions;
using MyOrm.Mappers;
using MyOrm.Reflections;
using MyOrm.SqlBuilder;
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;
using System.Linq;
using System.Linq.Expressions;
using System.Text;

namespace MyOrm.Queryable
{
    public class MyQueryable<T> where T : class , new ()
    {
        private readonly string _connectionString;

        // Navigation properties to load, keyed by property name. The value lists the
        // fields to select; an empty array means every mapped field of that property.
        private readonly Dictionary<string, string[]> _includeProperties = new Dictionary<string, string[]>();

        // Names of the navigation properties referenced by the Where clause
        private List<string> _whereProperties = new List<string>();

        // Cache of MyEntity metadata for navigation properties
        private readonly List<MyEntity> _entityCache = new List<MyEntity>();

        // Properties collected for the Select clause
        private readonly List<SelectResolveResult> _selectProperties = new List<SelectResolveResult>();

        // Metadata of the master table
        private readonly MyEntity _masterEntity;

        // Parameters required by the query
        private readonly MyDbParameters _parameters = new MyDbParameters();

        // Whether the Where method has already been called
        private bool _hasInitWhere;

        // The assembled WHERE clause
        private string _where;

        // The assembled ORDER BY clause
        private string _orderBy;

        /// <summary>
        /// Creates a query over <typeparamref name="T"/> and looks up the entity
        /// metadata of <typeparamref name="T"/> from <c>MyEntityContainer</c>.
        /// </summary>
        /// <param name="connectionString">Connection string used when the query is executed.</param>
        public MyQueryable(string connectionString)
        {
            var masterType = typeof(T);

            _connectionString = connectionString;
            _masterEntity = MyEntityContainer.Get(masterType);
        }

        #region Include
        /// <summary>
        /// Registers a navigation property to be joined and loaded with all of its mapped fields.
        /// </summary>
        /// <typeparam name="TProperty">Type of the navigation property.</typeparam>
        /// <param name="expression">
        /// A direct member access on the lambda parameter, e.g. <c>x =&gt; x.School</c>.
        /// Any other expression shape is ignored.
        /// </param>
        /// <returns>This instance, for chaining.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
        public MyQueryable<T> Include<TProperty>(Expression<Func<T, TProperty>> expression) where TProperty : IEntity
        {
            if (expression == null)
            {
                throw new ArgumentNullException(nameof(expression));
            }

            // The former "Member.GetType().IsClass" test inspected the reflection object
            // (a MemberInfo, always a class) instead of the property type, so it was always
            // true and has been dropped; TProperty is already constrained to IEntity.
            if (expression.Body is MemberExpression memberExpr &&
                memberExpr.Expression != null &&
                memberExpr.Expression.NodeType == ExpressionType.Parameter)
            {
                // TryAdd keeps the first registration when the same property is included twice.
                _includeProperties.TryAdd(memberExpr.Member.Name, new string[] { });
            }

            return this;
        }

        /// <summary>
        /// Registers a navigation property to be joined, restricted to the given fields.
        /// </summary>
        /// <typeparam name="TProperty">Type of the navigation property.</typeparam>
        /// <param name="property">Expression selecting the navigation property; ignored unless its body is a member access.</param>
        /// <param name="fields">Expression naming the fields of the navigation property to load.</param>
        /// <returns>This instance, for chaining.</returns>
        public MyQueryable<T> Include<TProperty>(
            Expression<Func<T, TProperty>> property,
            Expression<Func<TProperty, object>> fields) where TProperty : IEntity
        {
            if (property.Body.NodeType != ExpressionType.MemberAccess)
            {
                return this;
            }

            var memberVisitor = new ObjectMemberVisitor();

            // First pass: the name of the navigation property itself.
            memberVisitor.Visit(property);
            var navigationName = memberVisitor.GetPropertyList().First();

            // Second pass on the same visitor: the requested field names.
            memberVisitor.Clear();
            memberVisitor.Visit(fields);
            var requestedFields = memberVisitor.GetPropertyList();

            _includeProperties.TryAdd(navigationName, requestedFields.ToArray());
            return this;
        }

        /// <summary>
        /// Registers a navigation property, identified by name, to be joined and loaded
        /// with all of its mapped fields. Properties that are not join-able are ignored.
        /// </summary>
        /// <param name="navPropertyName">Name of a property on the master entity.</param>
        /// <returns>This instance, for chaining.</returns>
        /// <exception cref="InvalidOperationException">
        /// No property, or more than one property, of the master entity has that name.
        /// </exception>
        public MyQueryable<T> Include(string navPropertyName)
        {
            // Single throws when nothing matches, so a null check on the result is unnecessary.
            var property = _masterEntity.Properties.Single(p => p.Name == navPropertyName);

            if (property.JoinAble)
            {
                // TryAdd (as in the expression-based overloads) so that including the same
                // property twice does not throw a duplicate-key ArgumentException.
                _includeProperties.TryAdd(property.Name, new string[] { });
            }

            return this;
        }

        /// <summary>
        /// Registers a navigation property, identified by name, to be joined and loaded
        /// with only the given fields. Properties that are not join-able are ignored.
        /// </summary>
        /// <param name="navPropertyName">Name of a property on the master entity.</param>
        /// <param name="fields">
        /// Property names of the navigation entity to load; null or empty means all mapped fields.
        /// </param>
        /// <returns>This instance, for chaining.</returns>
        /// <exception cref="InvalidOperationException">
        /// No property, or more than one property, of the master entity has that name.
        /// </exception>
        public MyQueryable<T> Include(string navPropertyName, string[] fields)
        {
            // Single throws when nothing matches, so a null check on the result is unnecessary.
            var property = _masterEntity.Properties.Single(p => p.Name == navPropertyName);

            if (property.JoinAble)
            {
                // A null array would later fail when the field list is built; treat it as "all fields".
                // TryAdd so that including the same property twice does not throw.
                _includeProperties.TryAdd(property.Name, fields ?? new string[] { });
            }

            return this;
        }
        #endregion

        #region Where
        /// <summary>
        /// Resolves the predicate into a WHERE clause, its parameters and the navigation
        /// properties it references. May be called only once per query.
        /// </summary>
        /// <param name="expr">Filter predicate over the master entity.</param>
        /// <returns>This instance, for chaining.</returns>
        /// <exception cref="ArgumentException">Where has already been called on this query.</exception>
        public MyQueryable<T> Where(Expression<Func<T, bool>> expr)
        {
            if (_hasInitWhere)
            {
                throw new ArgumentException("每个查询只能调用一次Where方法");
            }

            // The flag is set before resolving, so a failed resolve still counts as "called".
            _hasInitWhere = true;

            var resolved = new QueryConditionResolver<T>(_masterEntity).Resolve(expr.Body);

            _where = resolved.Condition;
            _parameters.AddParameters(resolved.Parameters);
            _entityCache.AddRange(resolved.NavPropertyList);

            var referencedNames = new List<string>();
            foreach (var navProperty in resolved.NavPropertyList)
            {
                referencedNames.Add(navProperty.Name);
            }
            _whereProperties = referencedNames;

            return this;
        }
        #endregion

        #region OrderBy,ThenOrderBy
        /// <summary>
        /// Sets the ORDER BY clause to a single column, replacing any previous ordering.
        /// Expressions whose body is not a member access are ignored.
        /// </summary>
        /// <param name="expression">Member to sort by.</param>
        /// <param name="orderBy">Sort direction; ascending by default.</param>
        /// <returns>This instance, for chaining.</returns>
        public MyQueryable<T> OrderBy<TProperty>(Expression<Func<T, TProperty>> expression,
            MyDbOrderBy orderBy = MyDbOrderBy.Asc)
        {
            if (expression.Body.NodeType != ExpressionType.MemberAccess)
            {
                return this;
            }

            var column = GetOrderByString((MemberExpression)expression.Body);
            _orderBy = orderBy == MyDbOrderBy.Desc
                ? column + " DESC"
                : column;

            return this;
        }

        /// <summary>
        /// Appends a further column to an existing ORDER BY clause.
        /// Expressions whose body is not a member access are ignored.
        /// </summary>
        /// <param name="expression">Member to sort by.</param>
        /// <param name="orderBy">Sort direction; ascending by default.</param>
        /// <returns>This instance, for chaining.</returns>
        /// <exception cref="ArgumentNullException">No ORDER BY clause has been set yet.</exception>
        public MyQueryable<T> ThenOrderBy<TProperty>(Expression<Func<T, TProperty>> expression,
            MyDbOrderBy orderBy = MyDbOrderBy.Asc)
        {
            if (string.IsNullOrWhiteSpace(_orderBy))
            {
                throw new ArgumentNullException(nameof(_orderBy), "排序字段为空,必须先调用OrderBy或OrderByDesc才能调用此方法");
            }

            if (expression.Body.NodeType != ExpressionType.MemberAccess)
            {
                return this;
            }

            var column = GetOrderByString((MemberExpression)expression.Body);
            var direction = orderBy == MyDbOrderBy.Desc ? " DESC" : string.Empty;
            _orderBy = _orderBy + "," + column + direction;

            return this;
        }
        #endregion

        #region Select

        /// <summary>
        /// Projects the query onto <typeparamref name="TTarget"/> using the members named
        /// in <paramref name="expression"/>, and returns an executable <c>MySelect</c>.
        /// </summary>
        /// <typeparam name="TTarget">Type the result rows are mapped to.</typeparam>
        /// <param name="expression">Projection describing the members to select.</param>
        public MySelect<TTarget> Select<TTarget>(Expression<Func<T, object>> expression)
        {
            var resolver = new SelectExpressionResolver();
            resolver.Visit(expression);
            _selectProperties.AddRange(resolver.GetPropertyList());

            // GetFields must run before GetFrom: it rebuilds the include set that GetFrom joins on.
            var fieldClause = GetFields();
            var fromClause = GetFrom();

            return new MySelect<TTarget>(_connectionString, fieldClause, fromClause, _where, _parameters, _orderBy);
        }
        
        #endregion

        #region 输出
        /// <summary>
        /// Executes the query and maps every returned row to an entity.
        /// </summary>
        /// <returns>The mapped entities.</returns>
        public List<T> ToList()
        {
            var fields = GetFields();
            var from = GetFrom();

            var sqlBuilder = new SqlServerBuilder();
            var sql = sqlBuilder.Select(from, fields, _where, _orderBy);

            // Pass the included navigation properties to the converter, as ToPageList and
            // FirstOrDefault do; the field list above contains their columns as well.
            var converter = new SqlDataReaderConverter<T>(_includeProperties.Select(p => p.Key).ToArray());

            using (var conn = new SqlConnection(_connectionString))
            using (var command = new SqlCommand(sql, conn))
            {
                command.Parameters.AddRange(_parameters.Parameters);
                try
                {
                    conn.Open();
                    using (var sdr = command.ExecuteReader())
                    {
                        return converter.ConvertToEntityList(sdr);
                    }
                }
                finally
                {
                    // A SqlParameter may belong to only one parameter collection at a time;
                    // detach them so the same query object can be executed again.
                    command.Parameters.Clear();
                }
            }
        }

        /// <summary>
        /// Executes the query for one page and maps the returned rows to entities.
        /// </summary>
        /// <param name="pageIndex">Page index passed to the paging SQL builder.</param>
        /// <param name="pageSize">Page size passed to the paging SQL builder.</param>
        /// <param name="recordCount">
        /// Value of the <c>@RecordCount</c> output parameter; 0 when the statement did not set it.
        /// </param>
        /// <returns>The mapped entities of the requested page.</returns>
        public List<T> ToPageList(int pageIndex, int pageSize, out int recordCount)
        {
            var fields = GetFields();
            var from = GetFrom();

            var sqlBuilder = new SqlServerBuilder();
            var sql = sqlBuilder.PagingSelect(from, fields, _where, _orderBy, pageIndex, pageSize);

            var countParam = new SqlParameter("@RecordCount", SqlDbType.Int) { Direction = ParameterDirection.Output };
            List<T> result;

            using (var conn = new SqlConnection(_connectionString))
            using (var command = new SqlCommand(sql, conn))
            {
                command.Parameters.AddRange(_parameters.Parameters);
                command.Parameters.Add(countParam);
                try
                {
                    conn.Open();
                    using (var sdr = command.ExecuteReader())
                    {
                        var handler = new SqlDataReaderConverter<T>(_includeProperties.Select(p => p.Key).ToArray());
                        result = handler.ConvertToEntityList(sdr);
                    }
                }
                finally
                {
                    // A SqlParameter may belong to only one parameter collection at a time;
                    // detach them so the same query object can be executed again.
                    command.Parameters.Clear();
                }
            }

            // Read after the reader is closed; fall back to 0 rather than failing the
            // cast when the output parameter came back as DBNull.
            recordCount = countParam.Value is int total ? total : 0;
            return result;
        }

        /// <summary>
        /// Executes the query limited to one row and maps it to an entity.
        /// </summary>
        /// <returns>Whatever the converter produces for the first row of the result.</returns>
        public T FirstOrDefault()
        {
            var fields = GetFields();
            var from = GetFrom();

            var sqlBuilder = new SqlServerBuilder();
            var sql = sqlBuilder.Select(from, fields, _where, _orderBy, 1);

            using (var conn = new SqlConnection(_connectionString))
            using (var command = new SqlCommand(sql, conn))
            {
                command.Parameters.AddRange(_parameters.Parameters);
                try
                {
                    conn.Open();

                    // The reader was previously never disposed; close it deterministically.
                    using (var sdr = command.ExecuteReader())
                    {
                        var handler = new SqlDataReaderConverter<T>(_includeProperties.Select(p => p.Key).ToArray());
                        return handler.ConvertToEntity2(sdr);
                    }
                }
                finally
                {
                    // A SqlParameter may belong to only one parameter collection at a time;
                    // detach them so the same query object can be executed again.
                    command.Parameters.Clear();
                }
            }
        }
        #endregion

        #region 辅助方法

        /// <summary>
        /// Returns the entity metadata for a navigation property type, keeping the ones
        /// in use in a local list so the container is not consulted every time.
        /// </summary>
        /// <param name="type">CLR type of the navigation property.</param>
        private MyEntity GetIncludePropertyEntityInfo(Type type)
        {
            var fullName = type.FullName;

            foreach (var cached in _entityCache)
            {
                if (cached.Name == fullName)
                {
                    return cached;
                }
            }

            var loaded = MyEntityContainer.Get(type);
            _entityCache.Add(loaded);
            return loaded;
        }

        /// <summary>
        /// Builds the column list of the SELECT statement.
        /// </summary>
        /// <remarks>
        /// Without a projection, all mapped master columns are returned, followed by the
        /// columns of every included navigation property aliased as <c>Property_Field</c>.
        /// With a projection (see <c>Select</c>), only the projected members are returned
        /// and the include set is rebuilt from the navigation properties the projection uses.
        /// </remarks>
        /// <returns>A comma-separated column list; empty when nothing could be selected.</returns>
        public string GetFields()
        {
            if (_selectProperties.Count == 0)
            {
                var segments = new List<string>();

                var masterFields = string.Join(
                    ",",
                    _masterEntity
                        .Properties
                        .Where(p => p.IsMap)
                        .Select(p => $"[{_masterEntity.TableName}].[{p.FieldName}] AS [{p.Name}]")
                );
                if (masterFields.Length > 0)
                {
                    segments.Add(masterFields);
                }

                // Order by key: a KeyValuePair is not comparable, so ordering by the pair
                // itself throws as soon as two navigation properties are included.
                foreach (var include in _includeProperties.OrderBy(i => i.Key))
                {
                    var alias = include.Key;
                    var requested = include.Value;
                    var takeAll = requested == null || requested.Length == 0;

                    var prop = _masterEntity.Properties.Single(p => p.Name == alias);
                    var propEntity = GetIncludePropertyEntityInfo(prop.PropertyInfo.PropertyType);

                    // The joined table is aliased with the property name in GetFrom, so the
                    // columns must be qualified with that alias rather than the table name.
                    var includeFields = string.Join(
                        ",",
                        propEntity.Properties
                            .Where(p => p.IsMap && (takeAll || requested.Contains(p.Name)))
                            .Select(p => $"[{alias}].[{p.FieldName}] AS [{alias}_{p.Name}]")
                    );

                    // Skip empty segments so that no dangling comma is produced.
                    if (includeFields.Length > 0)
                    {
                        segments.Add(includeFields);
                    }
                }

                // Joining the segments also separates the columns of consecutive includes.
                return string.Join(",", segments);
            }

            _includeProperties.Clear();
            var selected = new List<string>();

            foreach (var property in _selectProperties)
            {
                if (string.IsNullOrWhiteSpace(property.FieldName))
                {
                    // Plain member of the master entity; Single throws for an unknown name.
                    var masterProp = _masterEntity.Properties.Single(p => p.Name == property.PropertyName);
                    selected.Add($"[{_masterEntity.TableName}].[{masterProp.FieldName}] AS [{property.MemberName}]");
                    continue;
                }

                // Member of a navigation property; skipped when the master entity has no such property.
                var navProp = _masterEntity.Properties.SingleOrDefault(p => p.Name == property.PropertyName);
                if (navProp == null)
                {
                    continue;
                }

                // TryAdd: several projected members may come from the same navigation property.
                _includeProperties.TryAdd(property.PropertyName, new string[] { });

                var navEntity = GetIncludePropertyEntityInfo(navProp.PropertyInfo.PropertyType);
                var field = navEntity.Properties.Single(p => p.Name == property.FieldName);
                selected.Add($"[{property.PropertyName}].[{field.FieldName}] AS [{property.MemberName}]");
            }

            return string.Join(",", selected);
        }

        /// <summary>
        /// Builds the FROM clause: the master table followed by one LEFT JOIN per navigation
        /// property that is either included or referenced by the Where clause.
        /// </summary>
        /// <returns>The FROM clause without the <c>FROM</c> keyword.</returns>
        /// <exception cref="InvalidOperationException">
        /// A registered navigation property name does not match exactly one master property.
        /// </exception>
        public string GetFrom()
        {
            var masterTable = $"[{_masterEntity.TableName}]";
            var allJoinProperties = _includeProperties.Select(p => p.Key).Concat(_whereProperties).Distinct().ToList();

            if (allJoinProperties.Count == 0)
            {
                return masterTable;
            }

            var sb = new StringBuilder(masterTable);
            foreach (var property in allJoinProperties)
            {
                // Single throws when nothing matches, so a null check on the result is unnecessary.
                var prop = _masterEntity.Properties.Single(p => p.Name == property);
                var propEntity = GetIncludePropertyEntityInfo(prop.PropertyInfo.PropertyType);

                // The joined table is aliased with the property name, so the ON clause has to
                // use that alias; the bare table name is no longer addressable once aliased.
                sb.Append($" LEFT JOIN [{propEntity.TableName}] AS [{property}] ON [{_masterEntity.TableName}].[{prop.ForeignKey}]=[{property}].[{propEntity.KeyColumn}]");
            }

            return sb.ToString();
        }

        /// <summary>
        /// Translates a member expression into a qualified column for the ORDER BY clause.
        /// </summary>
        /// <param name="expression">Member access on the master entity or on one of its navigation properties.</param>
        /// <returns>
        /// <c>[Table].[Field]</c> for a one-element member path, <c>[Property].[Field]</c> for a
        /// two-element path, and an empty string for any other path length.
        /// </returns>
        private string GetOrderByString(MemberExpression expression)
        {
            expression.GetRootType(out var memberPath);

            switch (memberPath.Count)
            {
                case 1:
                {
                    var name = memberPath.Pop();
                    var masterProperty = _masterEntity.Properties.Single(p => p.Name == name);
                    return $"[{_masterEntity.TableName}].[{masterProperty.FieldName}]";
                }
                case 2:
                {
                    // Pop order matters: the first name is looked up on the navigation
                    // entity, the second one on the master entity.
                    var slaveName = memberPath.Pop();
                    var navigationName = memberPath.Pop();

                    var navigationProperty = _masterEntity.Properties.Single(p => p.Name == navigationName);
                    var navigationEntity = GetIncludePropertyEntityInfo(navigationProperty.PropertyInfo.PropertyType);
                    var slaveField = navigationEntity.Properties.Single(p => p.Name == slaveName);

                    return $"[{navigationProperty.Name}].[{slaveField.FieldName}]";
                }
                default:
                    return string.Empty;
            }
        }

        #endregion
    }
}

二、按需查询 Select<T>()

using MyOrm.DbParameters;
using MyOrm.Mappers;
using MyOrm.SqlBuilder;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;

namespace MyOrm.Queryable
{
    public class MySelect<T>
    {
        // Connection string used when the query is executed
        private readonly string _connectionString;
        // Column list of the SELECT statement
        private readonly string _fields;
        // Table expression handed to the SQL builder (MyQueryable passes its FROM clause, joins included)
        private readonly string _table;
        // WHERE clause
        private readonly string _where;
        // ORDER BY clause
        private readonly string _orderBy;
        // Parameters referenced by the query
        private readonly MyDbParameters _parameters;

        /// <summary>
        /// Creates an executable projection query from already assembled SQL fragments.
        /// </summary>
        /// <param name="connectionString">Connection string used when the query is executed.</param>
        /// <param name="fields">Column list of the SELECT statement.</param>
        /// <param name="table">Table expression to select from.</param>
        /// <param name="where">WHERE clause.</param>
        /// <param name="dbParameters">Parameters referenced by the query.</param>
        /// <param name="orderBy">ORDER BY clause.</param>
        public MySelect(string connectionString, string fields, string table, string where, MyDbParameters dbParameters, string orderBy)
        {
            _connectionString = connectionString;

            _table = table;
            _fields = fields;
            _where = where;
            _orderBy = orderBy;

            _parameters = dbParameters;
        }

        /// <summary>
        /// Executes the query and maps every returned row to <typeparamref name="T"/>.
        /// </summary>
        /// <returns>The mapped rows.</returns>
        public List<T> ToList()
        {
            var sqlBuilder = new SqlServerBuilder();
            var sql = sqlBuilder.Select(_table, _fields, _where, _orderBy);

            var mapper = new SqlDataReaderMapper();

            using (var conn = new SqlConnection(_connectionString))
            using (var command = new SqlCommand(sql, conn))
            {
                command.Parameters.AddRange(_parameters.Parameters);
                try
                {
                    conn.Open();
                    using (var sdr = command.ExecuteReader())
                    {
                        return mapper.ConvertToList<T>(sdr);
                    }
                }
                finally
                {
                    // A SqlParameter may belong to only one parameter collection at a time;
                    // detach them so the same parameters can be attached to another command.
                    command.Parameters.Clear();
                }
            }
        }

        //public List<dynamic> DynamicList()
        //{
        //    var sqlBuilder = new SqlServerBuilder();
        //    var sql = sqlBuilder.Select(_table, _fields, _where, _orderBy);

        //    var visitor = new SqlDataReaderMapper();
        //    List<dynamic> result;
        //    using (var conn = new SqlConnection(_connectionString))
        //    {
        //        var command = new SqlCommand(sql, conn);
        //        command.Parameters.AddRange(_parameters.Parameters);
        //        conn.Open();
        //        using (var sdr = command.ExecuteReader())
        //        {
        //            result = visitor.ConvertToList(sdr);
        //        }
        //    }

        //    return result;
        //}

        /// <summary>
        /// Executes the query for one page and maps the returned rows to <typeparamref name="T"/>.
        /// </summary>
        /// <param name="pageIndex">Page index passed to the paging SQL builder.</param>
        /// <param name="pageSize">Page size passed to the paging SQL builder.</param>
        /// <param name="recordCount">
        /// Value of the <c>@RecordCount</c> output parameter; 0 when the statement did not set it.
        /// </param>
        /// <returns>The mapped rows of the requested page.</returns>
        public List<T> ToPageList(int pageIndex, int pageSize, out int recordCount)
        {
            var sqlBuilder = new SqlServerBuilder();
            var sql = sqlBuilder.PagingSelect2008(_table, _fields, _where, _orderBy, pageIndex, pageSize);

            var countParam = new SqlParameter("@RecordCount", SqlDbType.Int) { Direction = ParameterDirection.Output };
            List<T> result;

            using (var conn = new SqlConnection(_connectionString))
            using (var command = new SqlCommand(sql, conn))
            {
                command.Parameters.AddRange(_parameters.Parameters);
                command.Parameters.Add(countParam);
                try
                {
                    conn.Open();
                    using (var sdr = command.ExecuteReader())
                    {
                        var handler = new SqlDataReaderMapper();
                        result = handler.ConvertToList<T>(sdr);
                    }
                }
                finally
                {
                    // A SqlParameter may belong to only one parameter collection at a time;
                    // detach them so the same parameters can be attached to another command.
                    command.Parameters.Clear();
                }
            }

            // Read after the reader is closed; fall back to 0 rather than failing the
            // cast when the output parameter came back as DBNull.
            recordCount = countParam.Value is int total ? total : 0;
            return result;
        }

        //public List<dynamic> ToPageListDynamic(int pageIndex, int pageSize, out int recordCount)
        //{
        //    recordCount = 0;

        //    var sqlBuilder = new SqlServerBuilder();
        //    var sql = sqlBuilder.PagingSelect2008(_table, _fields, _where, _orderBy, pageIndex, pageSize);

        //    var command = new SqlCommand(sql);
        //    command.Parameters.AddRange(_parameters.Parameters);
        //    var param = new SqlParameter("@RecordCount", SqlDbType.Int) {Direction = ParameterDirection.Output};
        //    command.Parameters.Add(param);

        //    List<dynamic> result;

        //    using (var conn = new SqlConnection(_connectionString))
        //    {
        //        conn.Open();
        //        command.Connection = conn;
        //        using (var sdr = command.ExecuteReader())
        //        {    
        //            var handler = new SqlDataReaderMapper();
        //            result = handler.ConvertToList(sdr);
        //        }
        //    }

        //    recordCount = (int) param.Value;
        //    return result;
        //}

        /// <summary>
        /// Executes the query limited to one row and maps it to <typeparamref name="T"/>.
        /// </summary>
        /// <returns>Whatever the mapper produces for the first row of the result.</returns>
        public T FirstOrDefault()
        {
            var sqlBuilder = new SqlServerBuilder();
            var sql = sqlBuilder.Select(_table, _fields, _where, _orderBy, 1);

            using (var conn = new SqlConnection(_connectionString))
            using (var command = new SqlCommand(sql, conn))
            {
                command.Parameters.AddRange(_parameters.Parameters);
                try
                {
                    conn.Open();

                    // The reader was previously never disposed; close it deterministically.
                    using (var sdr = command.ExecuteReader())
                    {
                        var handler = new SqlDataReaderMapper();
                        return handler.ConvertToEntity<T>(sdr);
                    }
                }
                finally
                {
                    // A SqlParameter may belong to only one parameter collection at a time;
                    // detach them so the same parameters can be attached to another command.
                    command.Parameters.Clear();
                }
            }
        }

        //public dynamic FirstOrDefaultDynamic()
        //{
        //    var sqlBuilder = new SqlServerBuilder();
        //    var sql = sqlBuilder.Select(_table, _fields, _where, _orderBy, 1);

        //    using (var conn = new SqlConnection(_connectionString))
        //    {
        //        conn.Open();
        //        var command = new SqlCommand(sql, conn);
        //        command.Parameters.AddRange(_parameters.Parameters);
        //        var sdr = command.ExecuteReader();

        //        var handler = new SqlDataReaderMapper();
        //        return handler.ConvertToEntity(sdr);
        //    }
        //}
    }
}
原文地址:https://www.cnblogs.com/diwu0510/p/10663456.html