bzoj 4543 HOTEL 加强版

题目大意:

求树上取三个点这三个点两两距离相等的方案数

思路:

远古时候的$n^2$做法是换根 但那样无法继续优化了

学习了一波长链剖分

考虑如何在一棵树上进行dp

设$f[i][j]$表示以$i$为根的子树内与$i$的距离为$j$的点数量

$g[i][j]$表示以$i$为根的子树内满足与lca距离为$d$且lca与$i$的距离为$d-j$的点对数(lca在子树内)

对于每个子树 对答案的贡献为$g[x][0]$与对每一个新进来的子树$sum _{i=0} ^{mxd[v]} f[x][i-1]*g[v][i]+g[x][i+1]*f[v][i]$

更新$f,g$的时候除了正常的继承的子树 $g[x][i+1]+=f[x][i+1]*f[v][i]$

 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cstring>
 4 #include<cstdlib>
 5 #include<cmath>
 6 #include<algorithm>
 7 #include<queue>
 8 #include<vector>
 9 #include<map>
10 #include<set>
11 #define ll long long
12 #define inf 2139062143
13 #define MAXN 100100
14 #define MOD 998244353
15 #define rep(i,s,t) for(register int i=(s),i##__end=(t);i<=i##__end;++i)
16 #define dwn(i,s,t) for(register int i=(s),i##__end=(t);i>=i##__end;--i)
17 #define ren for(register int i=fst[x];i;i=nxt[i])
18 #define pb(i,x) vec[i].push_back(x)
19 #define pls(a,b) (a+b)%MOD
20 #define mns(a,b) (a-b+MOD)%MOD
21 #define mul(a,b) (1LL*(a)*(b))%MOD
22 using namespace std;
23 inline int read()
24 {
25     int x=0,f=1;char ch=getchar();
26     while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
27     while(isdigit(ch)) {x=x*10+ch-'0';ch=getchar();}
28     return x*f;
29 }
30 int n,*f[MAXN],*g[MAXN],tmp[MAXN<<2],*id=tmp,ans;
31 int fst[MAXN],nxt[MAXN<<1],to[MAXN<<1],cnt,mxd[MAXN],son[MAXN];
32 void add(int u,int v) {nxt[++cnt]=fst[u],fst[u]=cnt,to[cnt]=v;}
33 void dfs(int x,int pa)
34 {
35     ren if(to[i]^pa) {dfs(to[i],x);if(mxd[to[i]]>mxd[son[x]]) son[x]=to[i];}
36     mxd[x]=mxd[son[x]]+1;
37 }
38 void New(int x) {f[x]=id,id+=mxd[x]<<1,g[x]=id,id+=mxd[x]<<1;}
39 void dp(int x,int pa)
40 {
41     if(son[x]) {f[son[x]]=f[x]+1,g[son[x]]=g[x]-1;dp(son[x],x);}
42     f[x][0]=1,ans+=g[x][0];
43     ren if(to[i]^pa&&to[i]^son[x])
44     {
45         New(to[i]);dp(to[i],x);
46         rep(j,0,mxd[to[i]])
47         {
48             ans+=g[x][j+1]*f[to[i]][j];
49             if(j) ans+=f[x][j-1]*g[to[i]][j];
50         }
51         rep(j,0,mxd[to[i]])
52         {
53             g[x][j+1]+=f[x][j+1]*f[to[i]][j],f[x][j+1]+=f[to[i]][j];
54             if(j) g[x][j-1]+=g[to[i]][j];
55         }
56     }
57 }
58 int main()
59 {
60     n=read();int a,b;rep(i,2,n) a=read(),b=read(),add(a,b),add(b,a);
61     dfs(1,0);New(1);dp(1,0);printf("%d
",ans);
62 }
View Code
原文地址:https://www.cnblogs.com/yyc-jack-0920/p/10470411.html