题面
ZYB有一个有N个节点的树,现在他希望你求出与每个节点距离不超过K的节点数。两个节点(x,y)之间的距离定义为x到y的最短路径上经过的边数。为节省读入和输出时间,我们使用以下方式:读入:我们有两个数字A和B,让fai是节点i的父亲,fa1 = 0,fai =(A * i + B)%(i-1)+1其中i∈[2,N]。对于输出:让ansi成为节点i的答案,你只需要输出所有ansi的xor总和。
分析
先开始只想到了在同一棵子树里面的情况,真的是奇蠢无比...
分两种情况
f[u][j]表示以u为根的子树中,与u距离为j的点的数量。其中0<=j<=k。显然f[u][j]+=f[v][j-1]
dp[u][j]表示u这个点到不是以它为根的子树的结点距离为j的点的数量。首先我们已经知道了所有f[u][j]值了,其中对于根结点只存在第一种情况,所以从根往下算答案。
1.某个不在v的子树且不在fa的子树中的点到v这个点距离为j,那么它到v的父亲u距离就为j-1,这部分节点数量为dp[u][j-1]。
2.某个不在v的子树但在fa的子树中的点到v(即u兄弟子树中的结点到v)的距离为j,那么它到v的父亲u距离就为j-1,这部分结点数量为f[u][j-1]-f[v][j-2]。
因为f[u][j-1]中包含了在v的子树中的点,所以要减去到v的距离为j-2的点的数量。
如图,这样理解
求不在粉色的点的子树中且与粉色的点距离为2的点数量
蓝色,红色为合法路径。但黄色线的表达什么?黄色线下方末端是f[u][2]能到达的点,但是并不合法,因为他们在v的子树中,所以减去这两个点,即f[v][1]。
代码
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; #define N 600010 int a,b,k,t,n,cnt,tot,ans; int first[N],f[N][15],dp[N][15]; struct email { int u,v; int nxt; }e[N*4]; inline void add(int u,int v) { e[++cnt].nxt=first[u];first[u]=cnt; e[cnt].u=u;e[cnt].v=v; } inline void init() { ans=cnt=0; memset(f,0,sizeof(f)); memset(dp,0,sizeof(dp)); memset(first,0,sizeof(first)); } void dfs(int u,int fa) { f[u][0]=1; for(int i=first[u];i;i=e[i].nxt) { int v=e[i].v; if(v==fa)continue; dfs(v,u); for(int j=1;j<=k;j++) f[u][j]+=f[v][j-1]; } } void dfs2(int u,int fa) { for(int i=first[u];i;i=e[i].nxt) { int v=e[i].v; if(v==fa)continue; for(int j=2;j<=k;j++) dp[v][j]=f[u][j-1]-f[v][j-2]+dp[u][j-1]; dp[v][1]=f[u][0]; dfs2(v,u); } } int main() { scanf("%d",&t); while(t--) { init(); scanf("%d%d%d%d",&n,&k,&a,&b); for(int u=2;u<=n;u++) add(u,((long long)a*u+b)%(u-1)+1),add(((long long)a*u+b)%(u-1)+1,u); dfs(1,0);dfs2(1,0); for(int i=1;i<=n;i++) { tot=0; for(int j=0;j<=k;j++) tot+=(f[i][j]+dp[i][j]); ans^=tot; } printf("%d ",ans); } return 0; }