对CodeSmith + netTiers 生成DAL的一点补充

作者:肖波

个人博客:http://blog.csdn.net/eaglet ; http://www.cnblogs.com/eaglet

2007/7 南京


版本
CodeSmith 4.0
netTiers 2.0.1

背景
        最近在项目中使用CodeSmith + netTiers 生成数据访问层DAL,感觉效果很好,减少了大量的简单重复劳动。
不过在使用过程中发现CodeSmith提供的方法不能完全满足项目需要,主要体现在两个方面:
1、 Data.DataRepository.TableProvider.GetPaged方法无法输入带参数的条件,调用前必须进行SQL 拼接,
这样可能导致SQL 注入攻击。
2、 DataRepository.Provider.ExecuteDataSet 无法分页查询

为解决以上问题,我做了如下代码对生成的DAL进行了补充。这些代码可以在DAL外部使用,也可以修改netTiers
模板,内置到DAL中。


    /// <summary>
    
/// 带参数的条件查询子句异常
    
/// </summary>

    public class ParaWhereStringException : Exception
    
{
        
public ParaWhereStringException(String message) 
            : 
base(message)
        
{

        }

    }


    
/// <summary>
    
/// 带参数的条件查询子句
    
/// </summary>

    public class ParaWhereString
    
{
        
enum T_STATE
        
{
            Idle 
= 0,
            At   
= 1,
            Str  
= 2,
        }


        T_STATE m_State;
        
int m_LastPos;
        
int m_CurPos;
        String m_WhereString;
        List
<String> m_Words = new List<string>();

        
private void Clear()
        
{
            m_State 
= T_STATE.Idle;
            m_LastPos 
= 0;
            m_CurPos 
= 0;
            m_WhereString 
= "";
            m_Words 
= new List<string>();
        }


        
private void ChangeState(T_STATE curState)
        
{
            m_State 
= curState;
            NewWord();
        }


        
private void EndWord()
        
{
            m_Words.Add(m_WhereString.Substring(m_LastPos, m_WhereString.Length 
- m_LastPos));
        }


        
private void NewWord()
        
{
            m_Words.Add(m_WhereString.Substring(m_LastPos, m_CurPos 
- m_LastPos));
            m_LastPos 
= m_CurPos;
        }


        
private void StateMachine(char ch)
        
{
            
switch (m_State)
            
{
                
case T_STATE.Idle:
                    
if (ch == '@')
                    
{
                        ChangeState(T_STATE.At);
                    }

                    
else if (ch == '\'')
                    {
                        ChangeState(T_STATE.Str);
                    }


                    
break;
                
case T_STATE.At:
                    
if ((ch >= 'a' && ch <= 'z'|| (ch >= 'A' && ch <= 'Z'|| ch == '_')
                    
{
                        
break;
                    }


                    
if (ch >= '0' && ch <= '9' && m_CurPos - m_LastPos > 1)
                    
{
                        
break;
                    }


                    
if (ch == '\'')
                    {
                        ChangeState(T_STATE.Str);
                    }

                    
else
                    
{
                        ChangeState(T_STATE.Idle);
                    }


                    
break;

                
case T_STATE.Str:
                    
if (ch == '\'')
                    {
                        m_CurPos
++;

                        
if (m_WhereString[m_CurPos] == '\'')
                        {
                            
break;
                        }

                        
else
                        
{
                            ChangeState(T_STATE.Idle);
                        }

                    }

                    
break;
            }


            
if (m_CurPos == m_WhereString.Length - 1)
            
{
                
//无论任何状态,只要到了最后一个字符,结束状态机
                EndWord();
                
return;
            }

        }


        
private void SplitWhereString(String whereString)
        
{
            System.Diagnostics.Debug.Assert(whereString 
!= null);

            m_State 
= T_STATE.Idle;

            m_LastPos 
= 0;
            m_CurPos 
= 0;

            
while (m_CurPos < whereString.Length)
            
{
                StateMachine(whereString[m_CurPos]);
                m_CurPos
++;
            }

        }


        
private String GetParaValue(String paraName, object value)
        
{
            
if ((value is int|| (value is uint||
                (value 
is short|| (value is ushort||
                (value 
is sbyte|| (value is byte||
                (value 
is long|| (value is ulong||
                (value 
is float|| (value is double)
                )
            
{
                
return value.ToString();
            }


            
if ((value is string|| (value is char))
            
{
                
return "'" + value.ToString().Replace("'""''"+ "'";
            }


            
if (value is DateTime)
            
{
                DateTime d 
= (DateTime)value;

                
return "'" + d.ToString("yyyy-MM-dd HH:mm:ss"+ "'";
            }


            
if (value == DBNull.Value)
            
{
                
return "NULL";
            }


            
throw new ParaWhereStringException(String.Format("invalid type of para={0}!",
                paraName));
        }


        
/// <summary>
        
/// 根据参数获取条件子句
        
/// </summary>
        
/// <param name="whereString">
        
/// 带参数的条件子句,如
        
/// "Price>@MinPrice and Price < @MaxPrice"
        /// </param>
        
/// <param name="parameters">参数列表</param>
        
/// <returns>获取实际的条件子句,如 "Price > 10 and Price < 100"</returns>

        public String GetWhereString(String whereString, List<SqlParameter> parameters)
        
{
            
if (parameters == null)
            
{
                
return whereString;
            }


            Clear();

            m_WhereString 
= whereString;
            SplitWhereString(whereString);

            Hashtable table 
= new Hashtable();

            
foreach (SqlParameter para in parameters)
            
{
                
if (para.Value == null)
                
{
                    table[
'@' + para.ParameterName.ToLower()] = DBNull.Value;
                }

                
else
                
{
                    table[
'@' + para.ParameterName.ToLower()] = para.Value;
                }

            }


            StringBuilder whereStr 
= new StringBuilder();

            
foreach (String str in m_Words)
            
{
                
if (str.Length > 0)
                
{
                    
if (str[0== '@')
                    
{
                        
object value = table[str.ToLower().Trim()];
                        
if (value == null)
                        
{
                            
throw new ParaWhereStringException(String.Format("para={0} does not in parameters!",
                                str));
                        }


                        whereStr.Append(GetParaValue(str, value));
                        
continue;
                    }

                }


                whereStr.Append(str);
            }


            
return whereStr.ToString();
        }


    }


    
/// <summary>
    
/// 数据存储扩展
    
/// </summary>

    public class DataRepositoryEx
    
{
        
/// <summary>
        
/// 获取分页的查询结果,查询语句必须是
        
/// Select 形式的,不能处理存储过程
        
/// </summary>
        
/// <param name="fields">where 子句前面的部分,不能有top关键字 如 “Price,ReleaseTime, RecName as Address”</param>
        
/// <param name="tableName">要查询的表名</param>
        
/// <param name="condition">带参数的 where子句,不包括where关键字 如 “Price > @MinPrice and Price < @MaxPrice”</param>
        
/// <param name="parameters">where子句的参数</param>
        
/// <param name="orderBy">order by 子句部分, 如果有Group by 也可以写在这里 如“order by ReleaseTime ASC”</param>
        
/// <param name="pageNo">页面号,从0开始编号</param>
        
/// <param name="pageLength">页面长度,即每页面记录数</param>
        
/// <param name="count">输出查询结果的总数</param>
        
/// <returns>以数据表形式返回查询结果集</returns>

        static public DataTable SelectPaged(String fields, String tableName,
            String condition, List
<SqlParameter> parameters, String orderBy, int pageNo, int pageLength, out int count)
        
{
            System.Diagnostics.Debug.Assert(pageNo 
>= 0);
            System.Diagnostics.Debug.Assert(pageLength 
> 0);

            ParaWhereString paraWhereStr 
= new ParaWhereString();
            String sqlCond 
= paraWhereStr.GetWhereString(condition, parameters);

            String sql;

            
if (condition == null)
            
{
                condition 
= "";
            }


            
if (condition == "")
            
{
                sql 
= String.Format("select count(*) cnt from {0}", tableName);
            }

            
else
            
{
                sql 
= String.Format("select count(*) cnt from {0} where {1}", tableName, sqlCond);
            }


            DataSet ds 
= DataRepository.Provider.ExecuteDataSet(CommandType.Text, sql);

            count 
= (int)ds.Tables[0].Rows[0]["cnt"];

            
int upperBound = (pageNo + 1* pageLength;

            
int lowerBound = pageNo * pageLength;

            
if (condition == "")
            
{
                sql 
= String.Format("select top {0} {1} from {2} ", upperBound, fields, tableName);
            }

            
else
            
{
                sql 
= String.Format("select top {0} {1} from {2} where {3} ", upperBound, fields, tableName, sqlCond);
            }


            
if (orderBy != "" && orderBy != null)
            
{
                sql 
+= orderBy;
            }


            ds 
= DataRepository.Provider.ExecuteDataSet(CommandType.Text, sql);

            
if (ds.Tables[0].Rows.Count <= lowerBound)
            
{
                ds.Tables[
0].Clear();
            }

            
else
            
{
                
for (int i = 0; i < lowerBound; i++)
                
{
                    ds.Tables[
0].Rows.RemoveAt(0);
                }

            }


            
return ds.Tables[0];
        }

    }



ParaWhereString  类用于将带参数的条件子句转换为不带参数的条件子句,供GetPaged,GetAll两个方法使用。这个类是一个通用的类,也可以用于
其他应用中获取带参数的条件子句的最终转换后的条件子句。

DataRepositoryEx 类提供分查询的方法。

ParaWhereString 调用示例
            ParaWhereString paraWhereString = new ParaWhereString();

            
string whereString = "price>@minPrice and price <= @maxPrice and str like '%adb''@aaa dsafj'";

            List
<SqlParameter> paras = new List<SqlParameter>();
               paras.Add(new SqlParameter("minPrice", 100));
               paras.Add(new SqlParameter("MaxPrice", 1000));



            String sql 
= paraWhereString.GetWhereString(whereString, paras);

            Console.WriteLine(sql);
输出结果:
price>100 and price <= 1000 and str like '%adb''@aaa dsafj'

DataRepositoryEx 调用示例

用于测试的表结构

use Test

   GO

Create Table Test
(
id 
int identity (1,1not null,
int

)


向表Test中插入若干条连续的记录

查询分页数据示例


        
int count;

        List
<SqlParameter> paras = new List<SqlParameter>();
        paras.Add(
new SqlParameter("min"3));
        paras.Add(
new SqlParameter("max"30));

        DataTable table 
= SecUser.Cert.BLL.DataRepositoryEx.SelectPaged("id, a""test..test""id >= @min and id < @max"
            paras, 
"order by id DESC"010out count);

        Response.Write(String.Format(
"Count={0}", count));

        
foreach(DataRow row in table.Rows)
        
{
            Response.Write(String.Format(
"</p>{0}", row["id"]));
        }

查询结果:

Count=27

29

28

27

26

25

24

23

22

21

20

原文地址:https://www.cnblogs.com/eaglet/p/832427.html