VC++基于LSP实现数据拦截


LSP（Layered Service Provider）即分层服务提供程序。Winsock 是 Windows 为应用程序提供的网络套接字接口，可以通过称为“分层服务提供程序”的机制进行扩展。Winsock LSP 的实际用途非常广泛，包括 Internet 家长控制（parental control）和 Web 内容筛选。在 Windows XP SP2 之前的 Windows 版本中，删除有缺陷的（即所谓“buggy”）LSP 可能会损坏注册表中的 Winsock 目录，进而导致所有网络连接丢失。


请见代码



#define UNICODE
#define _UNICODE

#include <Winsock2.h>
#include <Ws2spi.h>
#include <Sporder.h>
#include <Windows.h>
#include <stdio.h>
#include <tchar.h>

#include "PhoenixLSP.h"

#include "../common/Debug.h"
#include "../common/PMacRes.h"
#include "Acl.h"

#pragma comment(lib, "Ws2_32.lib")


CAcl g_Acl;							// Access-control list; consulted by every hook below to decide whether a session is allowed

WSPUPCALLTABLE g_pUpCallTable;		// Upcall table handed to WSPStartup (a struct copy despite the "p" prefix). Only needed if the LSP creates its own pseudo handles; saved but unused here
WSPPROC_TABLE g_NextProcTable;		// Dispatch table of the next (lower) provider in the protocol chain; every hook forwards through it
TCHAR	g_szCurrentApp[MAX_PATH];	// Full path of the executable that loaded this DLL (filled in on DLL_PROCESS_ATTACH); used in debug output


/*
 * DllMain - DLL entry point.
 *
 * On process attach, records the full path of the host executable in
 * g_szCurrentApp so the hooks can tag their debug output with it.
 * All other notifications are ignored. Always returns TRUE.
 */
BOOL APIENTRY DllMain(HANDLE hModule, DWORD ul_reason_for_call, LPVOID lpReserved)
{
	if (ul_reason_for_call == DLL_PROCESS_ATTACH)
	{
		// NULL module handle = the main executable of the current process.
		::GetModuleFileName(NULL, g_szCurrentApp, MAX_PATH);
	}

	return TRUE;
}

/*
 * WSPStartup - entry point the Winsock DLL calls to initialize this LSP.
 *
 * Steps:
 *   1. Find the catalog entry of the provider directly below this layer
 *      (lpProtocolInfo->ProtocolChain.ChainEntries[1]).
 *   2. Resolve and load that provider's DLL and call its WSPStartup.
 *   3. Save the lower dispatch table in g_NextProcTable, then redirect
 *      socket/closesocket/bind/accept/connect/sendto/recvfrom in
 *      lpProcTable to this DLL's hooks.
 *
 * The lower provider receives its own catalog entry when it is a base
 * protocol, otherwise the chain entry passed to us.
 *
 * Returns ERROR_SUCCESS on success, WSAEPROVIDERFAILEDINIT if the lower
 * provider cannot be located or loaded, or the lower provider's own error
 * code if its WSPStartup fails.
 */
int WSPAPI WSPStartup(
  WORD wVersionRequested,
  LPWSPDATA lpWSPData,
  LPWSAPROTOCOL_INFO lpProtocolInfo,
  WSPUPCALLTABLE UpcallTable,
  LPWSPPROC_TABLE lpProcTable
)
{
	ODS1(L"  WSPStartup...  %s \n", g_szCurrentApp);

	// A layered entry needs at least two links: this LSP and the provider beneath it.
	if(lpProtocolInfo->ProtocolChain.ChainLen <= 1)
	{
		return WSAEPROVIDERFAILEDINIT;
	}

	// Keep the upcall table (not used by this LSP at present).
	g_pUpCallTable = UpcallTable;

	// Enumerate the catalog and copy the WSAPROTOCOL_INFOW of the next provider.
	WSAPROTOCOL_INFOW NextProtocolInfo;
	BOOL bFound = FALSE;
	int nTotalProtos = 0;
	LPWSAPROTOCOL_INFOW pProtoInfo = GetProvider(&nTotalProtos);
	if(pProtoInfo != NULL)
	{
		// Catalog entry ID of the provider directly beneath this layer.
		const DWORD dwBaseEntryId = lpProtocolInfo->ProtocolChain.ChainEntries[1];
		for(int i = 0; i < nTotalProtos; i++)
		{
			if(pProtoInfo[i].dwCatalogEntryId == dwBaseEntryId)
			{
				memcpy(&NextProtocolInfo, &pProtoInfo[i], sizeof(NextProtocolInfo));
				bFound = TRUE;
				break;
			}
		}
		// The entry we need has been copied; release the catalog snapshot
		// now so none of the early returns below can leak it.
		FreeProvider(pProtoInfo);
	}
	if(!bFound)
	{
		ODS(L" WSPStartup:	Can not find underlying protocol \n");
		return WSAEPROVIDERFAILEDINIT;
	}

	// Get the lower provider's DLL path (it may contain %VARIABLES%).
	int nError = 0;
	TCHAR szProviderPath[MAX_PATH];
	int nLen = MAX_PATH;
	if(::WSCGetProviderPath(&NextProtocolInfo.ProviderId, szProviderPath, &nLen, &nError) == SOCKET_ERROR)
	{
		ODS1(L" WSPStartup: WSCGetProviderPath() failed %d \n", nError);
		return WSAEPROVIDERFAILEDINIT;
	}

	// Expand into a separate buffer (no overlapping source/destination).
	// A return value larger than the buffer means the path was truncated.
	TCHAR szBaseProviderDll[MAX_PATH];
	const DWORD dwExpanded = ::ExpandEnvironmentStrings(szProviderPath, szBaseProviderDll, MAX_PATH);
	if(dwExpanded == 0 || dwExpanded > MAX_PATH)
	{
		ODS1(L" WSPStartup:  ExpandEnvironmentStrings() failed %d \n", ::GetLastError());
		return WSAEPROVIDERFAILEDINIT;
	}

	// Load the lower provider.
	HMODULE hModule = ::LoadLibrary(szBaseProviderDll);
	if(hModule == NULL)
	{
		ODS1(L" WSPStartup:  LoadLibrary() failed %d \n", ::GetLastError());
		return WSAEPROVIDERFAILEDINIT;
	}

	// Import the lower provider's WSPStartup.
	LPWSPSTARTUP pfnWSPStartup = reinterpret_cast<LPWSPSTARTUP>(::GetProcAddress(hModule, "WSPStartup"));
	if(pfnWSPStartup == NULL)
	{
		ODS1(L" WSPStartup:  GetProcAddress() failed %d \n", ::GetLastError());
		::FreeLibrary(hModule);		// unusable; do not keep it loaded
		return WSAEPROVIDERFAILEDINIT;
	}

	// A base provider is given its own catalog entry; a further layer gets the chain entry.
	LPWSAPROTOCOL_INFOW pInfo = lpProtocolInfo;
	if(NextProtocolInfo.ProtocolChain.ChainLen == BASE_PROTOCOL)
		pInfo = &NextProtocolInfo;

	int nRet = pfnWSPStartup(wVersionRequested, lpWSPData, pInfo, UpcallTable, lpProcTable);
	if(nRet != ERROR_SUCCESS)
	{
		ODS1(L" WSPStartup:  underlying provider's WSPStartup() failed %d \n", nRet);
		::FreeLibrary(hModule);		// lower provider failed to initialize
		return nRet;
	}

	// Save the lower provider's dispatch table; the hooks forward through it.
	// On success the lower DLL stays loaded; nothing in this function unloads it.
	g_NextProcTable = *lpProcTable;

	// Hand our hooks back to the caller in place of the lower entries.
	lpProcTable->lpWSPSocket = WSPSocket;
	lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
	lpProcTable->lpWSPBind = WSPBind;
	lpProcTable->lpWSPAccept = WSPAccept;
	lpProcTable->lpWSPConnect = WSPConnect;
	lpProcTable->lpWSPSendTo = WSPSendTo;
	lpProcTable->lpWSPRecvFrom = WSPRecvFrom;

	return nRet;
}



/*
 * WSPSocket - hook for socket()/WSASocket().
 *
 * Creates the socket through the next provider. On success it resolves any
 * FROM_PROTOCOL_INFO placeholders in af/type/protocol from lpProtocolInfo
 * and registers the socket with the ACL via CheckSocket.
 *
 * Returns the new socket, or INVALID_SOCKET (with *lpErrno set by the
 * lower provider).
 */
SOCKET WSPAPI WSPSocket(
	int			af,
	int			type,
	int			protocol,
	LPWSAPROTOCOL_INFOW lpProtocolInfo,
	GROUP		g,
	DWORD		dwFlags,
	LPINT		lpErrno
)
{
	const SOCKET sock = g_NextProcTable.lpWSPSocket(af, type, protocol, lpProtocolInfo, g, dwFlags, lpErrno);

	if (sock != INVALID_SOCKET)
	{
		// FROM_PROTOCOL_INFO means "take the value from lpProtocolInfo".
		const int nFamily   = (af == FROM_PROTOCOL_INFO)       ? lpProtocolInfo->iAddressFamily : af;
		const int nSockType = (type == FROM_PROTOCOL_INFO)     ? lpProtocolInfo->iSocketType    : type;
		const int nProto    = (protocol == FROM_PROTOCOL_INFO) ? lpProtocolInfo->iProtocol      : protocol;

		g_Acl.CheckSocket(sock, nFamily, nSockType, nProto);
	}

	return sock;
}

/*
 * WSPCloseSocket - hook for closesocket().
 *
 * Removes the socket's session from the ACL first, then closes it through
 * the next provider and returns that provider's result.
 */
int WSPAPI WSPCloseSocket(SOCKET s, LPINT lpErrno)
{
	g_Acl.CheckCloseSocket(s);

	const int nResult = g_NextProcTable.lpWSPCloseSocket(s, lpErrno);
	return nResult;
}

/*
 * WSPBind - hook for bind().
 *
 * Lets the ACL record the local address for this socket's session, then
 * forwards the call unchanged to the next provider.
 */
int WSPAPI WSPBind(
	SOCKET			s,
	const struct sockaddr* name,
	int				namelen,
	LPINT			lpErrno
)
{
	g_Acl.CheckBind(s, name);

	const int nResult = g_NextProcTable.lpWSPBind(s, name, namelen, lpErrno);
	return nResult;
}

/*
 * WSPConnect - hook for connect()/WSAConnect().
 *
 * Asks the ACL whether this socket may connect to `name`. If allowed, the
 * call is forwarded to the next provider. Otherwise it fails with
 * SOCKET_ERROR and *lpErrno = WSAECONNREFUSED, without contacting the
 * remote host.
 */
int WSPAPI WSPConnect(
	SOCKET			s,
	const struct	sockaddr FAR * name,
	int				namelen,
	LPWSABUF		lpCallerData,
	LPWSABUF		lpCalleeData,
	LPQOS			lpSQOS,
	LPQOS			lpGQOS,
	LPINT			lpErrno
)
{
	ODS1(L" WSPConnect...	%s", g_szCurrentApp);

	const BOOL bAllowed = (g_Acl.CheckConnect(s, name) == PF_PASS);
	if(bAllowed)
	{
		return g_NextProcTable.lpWSPConnect(s, name, namelen, lpCallerData, lpCalleeData,
											lpSQOS, lpGQOS, lpErrno);
	}

	// Denied by the ACL: report it as a refused connection.
	*lpErrno = WSAECONNREFUSED;
	ODS1(L" WSPConnect deny a query %s \n", g_szCurrentApp);
	return SOCKET_ERROR;
}

/*
 * WSPAccept - hook for accept()/WSAAccept().
 *
 * Lets the next provider accept the pending connection, then asks the ACL
 * (CheckAccept) whether the new socket is permitted. A rejected socket is
 * closed and the call fails with *lpErrno = WSAECONNREFUSED.
 *
 * Returns the accepted socket, or INVALID_SOCKET on failure (the SPI's
 * error value for a SOCKET-returning function).
 */
SOCKET WSPAPI WSPAccept(
	SOCKET			s,
	struct sockaddr FAR *addr,
	LPINT			addrlen,
	LPCONDITIONPROC	lpfnCondition,
	DWORD			dwCallbackData,
	LPINT			lpErrno
)
{
	ODS1(L"  PhoenixLSP:  WSPAccept  %s \n", g_szCurrentApp);

	// Accept through the lower provider first.
	SOCKET sNew = g_NextProcTable.lpWSPAccept(s, addr, addrlen, lpfnCondition, dwCallbackData, lpErrno);
	if (sNew == INVALID_SOCKET)
		return INVALID_SOCKET;		// *lpErrno already set by the lower provider

	if (g_Acl.CheckAccept(s, sNew, addr) != PF_PASS)
	{
		// The close below goes straight to the lower provider, bypassing this
		// LSP's WSPCloseSocket hook, so perform the same ACL cleanup that hook
		// does before closing.
		g_Acl.CheckCloseSocket(sNew);

		int iError;
		g_NextProcTable.lpWSPCloseSocket(sNew, &iError);
		*lpErrno = WSAECONNREFUSED;
		return INVALID_SOCKET;
	}

	return sNew;
}


/*
 * WSPSendTo - hook for sendto()/WSASendTo().
 *
 * If the ACL permits sending to lpTo, the call is forwarded unchanged to the
 * next provider. Otherwise the socket is shut down in both directions and
 * the call fails with SOCKET_ERROR and *lpErrno = WSAECONNABORTED.
 */
int WSPAPI WSPSendTo(
	SOCKET			s,
	LPWSABUF		lpBuffers,
	DWORD			dwBufferCount,
	LPDWORD			lpNumberOfBytesSent,
	DWORD			dwFlags,
	const struct sockaddr FAR * lpTo,
	int				iTolen,
	LPWSAOVERLAPPED	lpOverlapped,
	LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
	LPWSATHREADID	lpThreadId,
	LPINT			lpErrno
)
{
	ODS1(L" query send to... %s \n", g_szCurrentApp);

	if (g_Acl.CheckSendTo(s, lpTo) == PF_PASS)
	{
		return g_NextProcTable.lpWSPSendTo(s, lpBuffers, dwBufferCount, lpNumberOfBytesSent,
										   dwFlags, lpTo, iTolen, lpOverlapped,
										   lpCompletionRoutine, lpThreadId, lpErrno);
	}

	// Denied: shut the socket down; the shutdown result is ignored.
	int nShutdownErr;
	g_NextProcTable.lpWSPShutdown(s, SD_BOTH, &nShutdownErr);
	*lpErrno = WSAECONNABORTED;

	ODS1(L" WSPSendTo deny query %s \n", g_szCurrentApp);

	return SOCKET_ERROR;
}

/*
 * WSPRecvFrom - hook for recvfrom()/WSARecvFrom().
 *
 * Asks the ACL (CheckRecvFrom) before receiving. If permitted, the call is
 * forwarded unchanged to the next provider. Otherwise the socket is shut
 * down in both directions and the call fails with SOCKET_ERROR and
 * *lpErrno = WSAECONNABORTED.
 *
 * NOTE(review): lpFrom is an output parameter of recvfrom and is only filled
 * in by the lower provider's call, yet it is passed to CheckRecvFrom before
 * that call. The ACL therefore sees the caller's buffer as-is (possibly NULL
 * or uninitialized). Confirm what CAcl::CheckRecvFrom expects.
 */
int WSPAPI WSPRecvFrom (
	SOCKET			s,
	LPWSABUF		lpBuffers,
	DWORD			dwBufferCount,
	LPDWORD			lpNumberOfBytesRecvd,
	LPDWORD			lpFlags,
	struct sockaddr FAR * lpFrom,
	LPINT			lpFromlen,
	LPWSAOVERLAPPED lpOverlapped,
	LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
	LPWSATHREADID	lpThreadId,
	LPINT			lpErrno
)
{
	ODS1(L"  PhoenixLSP:  WSPRecvFrom %s \n", g_szCurrentApp);

	if (g_Acl.CheckRecvFrom(s, lpFrom) == PF_PASS)
	{
		return g_NextProcTable.lpWSPRecvFrom(s, lpBuffers, dwBufferCount, lpNumberOfBytesRecvd,
											 lpFlags, lpFrom, lpFromlen, lpOverlapped,
											 lpCompletionRoutine, lpThreadId, lpErrno);
	}

	// Denied: shut the socket down; the shutdown result is ignored.
	int nShutdownErr;
	g_NextProcTable.lpWSPShutdown(s, SD_BOTH, &nShutdownErr);
	*lpErrno = WSAECONNABORTED;

	ODS1(L" WSPRecvFrom deny query %s \n", g_szCurrentApp);
	return SOCKET_ERROR;
}



/*
 * GetProvider - takes a snapshot of the Winsock catalog.
 *
 * Returns a GlobalAlloc'ed array of every installed protocol entry and
 * stores its element count in *lpnTotalProtocols. On any failure
 * (size query, allocation, or enumeration) it returns NULL and sets
 * *lpnTotalProtocols to 0, so callers never see an uninitialized or
 * negative count.
 *
 * Release the array with FreeProvider().
 */
LPWSAPROTOCOL_INFOW GetProvider(LPINT lpnTotalProtocols)
{
	*lpnTotalProtocols = 0;

	DWORD dwSize = 0;
	int nError = 0;

	// Size query: with no buffer the call is expected to fail with WSAENOBUFS
	// and report the required size in dwSize.
	if(::WSCEnumProtocols(NULL, NULL, &dwSize, &nError) == SOCKET_ERROR && nError != WSAENOBUFS)
		return NULL;

	LPWSAPROTOCOL_INFOW pProtoInfo = static_cast<LPWSAPROTOCOL_INFOW>(::GlobalAlloc(GPTR, dwSize));
	if(pProtoInfo == NULL)
		return NULL;

	const int nCount = ::WSCEnumProtocols(NULL, pProtoInfo, &dwSize, &nError);
	if(nCount == SOCKET_ERROR)
	{
		// Do not hand back a buffer whose contents were never filled in.
		::GlobalFree(pProtoInfo);
		return NULL;
	}

	*lpnTotalProtocols = nCount;
	return pProtoInfo;
}

/*
 * FreeProvider - releases a catalog array obtained from GetProvider().
 * The array was allocated with GlobalAlloc(GPTR, ...), so the pointer doubles
 * as the memory handle passed to GlobalFree.
 */
void FreeProvider(LPWSAPROTOCOL_INFOW pProtoInfo)
{
	HGLOBAL hMem = pProtoInfo;
	::GlobalFree(hMem);
}

/*

int WSPAPI WSPStartup(
  WORD wVersionRequested,
  LPWSPDATA lpWSPData,
  LPWSAPROTOCOL_INFO lpProtocolInfo,
  WSPUPCALLTABLE UpcallTable,
  LPWSPPROC_TABLE lpProcTable
)
{
	ODS1(L"  PhoenixLSP:  WSPStartup  %s \n", g_szCurrentApp);

	ODS1(L" %s", lpProtocolInfo->szProtocol);
	
	if(lpProtocolInfo->ProtocolChain.ChainLen <= 1)
	{	
		::OutputDebugString(L" Chain len <= 1 \n");
		return WSAEPROVIDERFAILEDINIT;
	}
	
	g_pUpCallTable = UpcallTable;
	int nTotalProtos;
	LPWSAPROTOCOL_INFOW pProtoInfo = GetProvider(&nTotalProtos);

	
	// 找到下层协议	
	WSAPROTOCOL_INFOW	NextProtocolInfo;
	// 下层入口ID	
	DWORD dwBaseEntryId = lpProtocolInfo->ProtocolChain.ChainEntries[1];
	for(int i=0; i<nTotalProtos; i++)
	{
		if(pProtoInfo[i].dwCatalogEntryId == dwBaseEntryId)
		{
			memcpy(&NextProtocolInfo, &pProtoInfo[i], sizeof(NextProtocolInfo));
			break;
		}
	}
	if(i >= nTotalProtos)
	{
		::OutputDebugString(L" Can not find next protocol <= 1 \n");
		return WSAEPROVIDERFAILEDINIT;
	}

	// 加载下层协议的DLL
	int nError;
	TCHAR szBaseProviderDll[MAX_PATH];
	int nLen = MAX_PATH;
	if(::WSCGetProviderPath(&NextProtocolInfo.ProviderId, szBaseProviderDll, &nLen, &nError) == SOCKET_ERROR)
	{
		::OutputDebugString(L" WSCGetProviderPath() failed \n");
		return WSAEPROVIDERFAILEDINIT;
	}
	if(!::ExpandEnvironmentStrings(szBaseProviderDll, szBaseProviderDll, MAX_PATH))
	{
		::OutputDebugString(L" ExpandEnvironmentStrings() failed \n");
		return WSAEPROVIDERFAILEDINIT;
	}
	HMODULE hModule = ::LoadLibrary(szBaseProviderDll);
	if(hModule == NULL)
	{
		::OutputDebugString(L" LoadLibrary() failed \n");
		return WSAEPROVIDERFAILEDINIT;
	}

	// 调用下层协议的WSPStartup函数
	LPWSPSTARTUP  proWSPStartup = NULL;
	proWSPStartup = (LPWSPSTARTUP)::GetProcAddress(hModule, "WSPStartup");
	if(proWSPStartup == NULL)
	{
		::OutputDebugString(L" GetProcAddress() failed \n");
		return WSAEPROVIDERFAILEDINIT;
	}

	LPWSAPROTOCOL_INFOW pInfo = lpProtocolInfo;
	if(NextProtocolInfo.ProtocolChain.ChainLen == BASE_PROTOCOL)
		pInfo = &NextProtocolInfo;

	int nRet = proWSPStartup(wVersionRequested, lpWSPData, pInfo, UpcallTable, lpProcTable);
	if(nRet != ERROR_SUCCESS)
	{
		ODS1(L" next layer's WSPStartup() failed %d \n ", nRet);
		return nRet;
	}

	// 保存下层协议的函数表
	g_NextProcTable = *lpProcTable;

	// 传给上层
	lpProcTable->lpWSPSocket = WSPSocket;
	lpProcTable->lpWSPCloseSocket = WSPCloseSocket;

	lpProcTable->lpWSPBind = WSPBind;

	// tcp
	lpProcTable->lpWSPAccept = WSPAccept;
	lpProcTable->lpWSPConnect = WSPConnect;

	// udp raw
	lpProcTable->lpWSPSendTo = WSPSendTo;	
	lpProcTable->lpWSPRecvFrom = WSPRecvFrom; 



	FreeProvider(pProtoInfo);
	return nRet;
}

*/




////////////////////////////////////////////////////////////////////////////////

/*	lpProcTable->lpWSPSend = WSPSend;
	lpProcTable->lpWSPRecv = WSPRecv;
*/
/*

  int WSPAPI WSPSend(
	SOCKET			s,
	LPWSABUF		lpBuffers,
	DWORD			dwBufferCount,
	LPDWORD			lpNumberOfBytesSent,
	DWORD			dwFlags,
	LPWSAOVERLAPPED	lpOverlapped,
	LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
	LPWSATHREADID	lpThreadId,
	LPINT			lpErrno
)
{
//	ODS1(L"  PhoenixLSP:  WSPSend  %s \n", g_szCurrentApp);

	// ?? 多个Buf如何处理
	if (g_Acl.CheckSend(s, lpBuffers[0].buf, *lpNumberOfBytesSent) != PF_PASS)
	{
		int		iError;
		g_NextProcTable.lpWSPShutdown(s, SD_BOTH, &iError);
		*lpErrno = WSAECONNABORTED;

		ODS(L" deny a send ");
		return SOCKET_ERROR;
	}

	return g_NextProcTable.lpWSPSend(s, lpBuffers, dwBufferCount, 
				lpNumberOfBytesSent, dwFlags, lpOverlapped, lpCompletionRoutine, lpThreadId, lpErrno);
}

int WSPAPI WSPRecv(
	SOCKET			s,
	LPWSABUF		lpBuffers,
	DWORD			dwBufferCount,
	LPDWORD			lpNumberOfBytesRecvd,
	LPDWORD			lpFlags,
	LPWSAOVERLAPPED	lpOverlapped,
	LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
	LPWSATHREADID	lpThreadId,
	LPINT			lpErrno
)
{
	ODS1(L"  PhoenixLSP:  WSPRecv  %s \n", g_szCurrentApp);

	if(lpOverlapped != NULL)
	{
		if(g_Acl.CheckRecv(s, NULL, 0) != PF_PASS)
		{
			int		iError;
			g_NextProcTable.lpWSPShutdown(s, SD_BOTH, &iError);
			*lpErrno = WSAECONNABORTED;

			ODS(L"deny a recv");
			return SOCKET_ERROR;
		}
		ODS(L" overlappped ");
	}

	int	iRet = g_NextProcTable.lpWSPRecv(s, lpBuffers, dwBufferCount, lpNumberOfBytesRecvd, lpFlags, lpOverlapped
				, lpCompletionRoutine, lpThreadId, lpErrno);

	if(iRet != SOCKET_ERROR && lpOverlapped == NULL)
	{
		if(g_Acl.CheckRecv(s, lpBuffers[0].buf, *lpNumberOfBytesRecvd) != PF_PASS)
		{
			int		iError;
			g_NextProcTable.lpWSPShutdown(s, SD_BOTH, &iError);
			*lpErrno = WSAECONNABORTED;

			ODS(L"deny a recv");
			return SOCKET_ERROR;
		}
	}
	return iRet;
}

  */


原文地址:https://www.cnblogs.com/new0801/p/6177674.html