洛谷P5206 数树

题意:

task0,给定两棵树T1,T2,取它们公共边(两端点相同)加入一张新的图,记新图连通块个数为x,求yx

task1,给定T1,求所有T2的task0之和。

task2,求所有T1的task1之和。

解:y = 1的时候特殊处理,就是总方案数。

task0,显然按照题意模拟即可。

task1,对某个T2,设有k条边相同,那么连通块数就是n - k。要求的就是

对于每个T2,前面yn都是一样的,所以直接去掉,最后乘上即可。关注后面这个东西怎么求。令y' = 1/y,E是公共边集。

注意到

这里下式是枚举边集E的子集S,对每个的子集贡献求和。

注意上式先枚举a再求组合数,相当于枚举在边集里选a条边,然后枚举选哪a条边。也就是枚举子集。

也就是下面这段话想表达的。摘自Cyhlnj。里面还提到了一个n3的矩阵树定理做法,神奇。

容斥写法在下不会T_T

下一步,把S提前枚举,在不同的E中同一个S的贡献总是相同的。考虑一个S会对哪些E产生贡献,也就是它的贡献会被计算多少次。

这|S|条边会形成若干个连通块。这些连通块通过加上一些边可以形成树。这些新边没有任何限制,于是就是连通块的生成树计数。

这里又有若干种推法......个人认为最简单的是利用prufer序列求解。

摘自Joker_69

令z = y' - 1,m = 边集为S时的连通块数 = n - |S|,第i号连通块有ai个点,于是我们的答案变成了这样:

这个东西怎么求呢?注意到在T1中选择任意的边集S等价于把T1划分为若干个连通块,用这些边连起来。于是就考虑树形DP。

这后面这个求积,要乘上每个连通块的大小,有个暴力是f[x][i]表示x为根,x所在连通块大小为i的所有方案权值和。

n2过不了,于是换个思路就是在每个连通块中选一个关键点的方案数。

因为是以联通块为依据DP,所以变形一下,加上之前忽略的yn,我们有:

于是状态设计有f[x][0/1]表示在以x为根的子树中,x所在连通块是否选择了关键点的所有方案权值之和。

每个联通块的贡献是z-1n,且我们只在关键点被选出来的那一瞬间计算这个联通块的贡献。

同时由于每个连通块的贡献要乘起来,那么所有方案之和还是要乘起来,等价于每个方案两两求积再相加。

口胡了半天还是写一下方程吧。

f[x][0] = 1;
f[x][1] = invz * n % MO;

LL t0 = f[x][0] * f[y][1] % MO + f[x][0] * f[y][0] % MO;
LL t1 = f[x][0] * f[y][1] % MO + f[x][1] * f[y][0] % MO + f[x][1] * f[y][1] % MO;
f[x][0] = t0 % MO;
f[x][1] = t1 % MO;

task2:问题变得严重起来......

跟task1一样,对于某个T1和T2的组合,它的贡献仍能拆成它的子集的贡献。

设g(E)为给定E这个边集之后的生成树个数,由task1可得g(E) = nm-2∏ai

枚举E为一定相同的边集,剩下的边随便连。那么对于T1的g(E)种情况,T2都有g(E)种情况。

所以E这个边集的贡献为z|E|g2(E)。

m还是连通块数,我们暴力展开g(E),并把与m有关的项放到∏里面,无关的提到外面,令r = n2/z,那么答案就是:

接下来这一步很毒瘤...我们考虑这个式子有什么实际意义。

前面的边集把这个图分成了若干个森林。每个连通块一定是树。后面相当于给每个连通块赋了ai2r的权值,并把权值乘起来作为这个边集的贡献。

设fi为大小为i的树的贡献,对应的EGF是F(x),gi为大小为i的图的贡献,对应的EGF是G(x)

那么有这样一个式子:G(x) = eF(x)

考虑fi是多少:每种树的权值都是i2r,一共有ii-2种树,贡献加起来是iir。

这样就对F(x)做exp,然后拿G(x)的第n项出来搞一搞就是答案了。


多项式操作别写错了......我一开始WA了20分是因为有这样的一句话:n * n

然后两个n乘起来爆int了......这题神奇的一批...

  1 #include <cstdio>
  2 #include <algorithm>
  3 #include <cstring>
  4 
  5 typedef long long LL;
  6 const int N = 100010;
  7 const LL MO = 998244353;
  8 
  9 inline LL qpow(LL a, LL b) {
 10     LL ans = 1;
 11     a %= MO;
 12     while(b) {
 13         if(b & 1) ans = ans * a % MO;
 14         a = a * a % MO;
 15         b = b >> 1;
 16     }
 17     return ans;
 18 }
 19 
 20 struct Edge {
 21     int nex, v;
 22 }edge[N << 1]; int tp;
 23 
 24 int n, e[N];
 25 LL Y, z;
 26 
 27 inline void add(int x, int y) {
 28     tp++;
 29     edge[tp].v = y;
 30     edge[tp].nex = e[x];
 31     e[x] = tp;
 32     return;
 33 }
 34 
 35 namespace t0 {
 36     int fa[N];
 37     void DFS(int x, int f) {
 38         fa[x] = f;
 39         for(int i = e[x]; i; i = edge[i].nex) {
 40             int y = edge[i].v;
 41             if(y == f) continue;
 42             DFS(y, x);
 43         }
 44         return;
 45     }
 46     inline void solve() {
 47         if(Y == 1) {
 48             puts("1");
 49             return;
 50         }
 51         for(int i = 1, x, y; i < n; i++) {
 52             scanf("%d%d", &x, &y);
 53             add(x, y);
 54             add(y, x);
 55         }
 56         DFS(1, 0);
 57         int k = 0;
 58         for(int i = 1, x, y; i < n; i++) {
 59             scanf("%d%d", &x, &y);
 60             if(fa[x] == y || fa[y] == x) {
 61                 k++;
 62             }
 63         }
 64         LL ans = qpow(Y, n - k);
 65         printf("%lld
", ans);
 66         return;
 67     }
 68 }
 69 
 70 namespace t1 {
 71     LL f[N][2], invz;
 72     void DFS(int x, int father) {
 73         f[x][0] = 1;
 74         f[x][1] = invz * n % MO;
 75         //printf("x = %d fa = %d 
", x, father);
 76         for(int i = e[x]; i; i = edge[i].nex) {
 77             int y = edge[i].v;
 78             //printf("y = %d 
", y);
 79             if(y == father) continue;
 80             DFS(y, x);
 81             LL t0 = f[x][0] * f[y][1] % MO + f[x][0] * f[y][0] % MO;
 82             LL t1 = f[x][0] * f[y][1] % MO + f[x][1] * f[y][0] % MO + f[x][1] * f[y][1] % MO;
 83             f[x][0] = t0 % MO;
 84             f[x][1] = t1 % MO;
 85         }
 86         //printf("X = %d f[x][0] = %lld f[x][1] = %lld 
", x, f[x][0], f[x][1]);
 87         return;
 88     }
 89     inline void solve() {
 90         if(Y == 1) {
 91             LL ans = qpow(n, n - 2);
 92             printf("%lld
", ans);
 93             return;
 94         }
 95         z = qpow(Y, MO - 2); z = (z - 1 + MO) % MO;
 96         invz = qpow(z, MO - 2);
 97         for(int i = 1, x, y; i < n; i++) {
 98             scanf("%d%d", &x, &y);
 99             add(x, y); add(y, x);
100         }
101         DFS(1, 0);
102         LL ans = f[1][1] * qpow(n, MO - 3) % MO * qpow(z, n) % MO * qpow(Y, n) % MO;
103         printf("%lld
", ans);
104         return;
105     }
106 }
107 
108 namespace t2 {
109 
110     typedef LL arr[N * 4];
111     const LL G = 3;
112 
113     int r[N * 4];
114     arr A, B, a, b, inv_t, exp_t, ln_t, ln_t2;
115     LL pw[N];
116 
117     inline void prework(int n) {
118         static int R = 0;
119         if(R == n) return;
120         R = n;
121         int lm = 1;
122         while((1 << lm) < n) lm++;
123         for(int i = 0; i < n; i++) {
124             r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lm - 1));
125         }
126         return;
127     }
128 
129     inline void NTT(LL *a, int n, int f) {
130         prework(n);
131         for(int i = 0; i < n; i++) {
132             if(i < r[i]) std::swap(a[i], a[r[i]]);
133         }
134         for(int len = 1; len < n; len <<= 1) {
135             LL Wn = qpow(G, (MO - 1) / (len << 1));
136             if(f == -1) Wn = qpow(Wn, MO - 2);
137             for(int i = 0; i < n; i += (len << 1)) {
138                 LL w = 1;
139                 for(int j = 0; j < len; j++) {
140                     LL t = a[i + len + j] * w % MO;
141                     a[i + len + j] = (a[i + j] - t) % MO;
142                     a[i + j] = (a[i + j] + t) % MO;
143                     w = w * Wn % MO;
144                 }
145             }
146         }
147         if(f == -1) {
148             LL inv = qpow(n, MO - 2);
149             for(int i = 0; i < n; i++) {
150                 a[i] = a[i] * inv % MO;
151             }
152         }
153         return;
154     }
155 
156     void Inv(const LL *a, LL *b, int n) {
157         if(n == 1) {
158             b[0] = qpow(a[0], MO - 2);
159             b[1] = 0;
160             return;
161         }
162         Inv(a, b, n >> 1);
163         /// ans = b[i] * (2 - a[i] * b[i])
164         memcpy(A, a, n * sizeof(LL)); memset(A + n, 0, n * sizeof(LL));
165         memcpy(B, b, n * sizeof(LL)); memset(B + n, 0, n * sizeof(LL));
166         NTT(A, n << 1, 1); NTT(B, n << 1, 1);
167         for(int i = 0; i < (n << 1); i++) b[i] = B[i] * (2 - A[i] * B[i] % MO) % MO;
168         NTT(b, n << 1, -1);
169         memset(b + n, 0, n * sizeof(LL));
170         return;
171     }
172 
173     inline void getInv(const LL *a, LL *b, int n) {
174         int len = 1;
175         while(len < n) len <<= 1;
176         memcpy(inv_t, a, n * sizeof(LL)); memset(inv_t + n, 0, (len - n) * sizeof(LL));
177         Inv(inv_t, b, len);
178         memset(b + n, 0, (len - n) * sizeof(LL));
179         return;
180     }
181 
182     inline void der(const LL *a, LL *b, int n) {
183         for(int i = 0; i < n - 1; i++) {
184             b[i] = a[i + 1] * (i + 1) % MO;
185         }
186         b[n - 1] = 0;
187         return;
188     }
189 
190     inline void ter(const LL *a, LL *b, int n) {
191         for(int i = n - 1; i >= 1; i--) {
192             b[i] = a[i - 1] * qpow(i, MO - 2) % MO;
193         }
194         b[0] = 0;
195         return;
196     }
197 
198     inline void getLn(const LL *a, LL *b, int n) {
199         getInv(a, ln_t, n);
200         der(a, ln_t2, n);
201         int len = 1;
202         while(len < 2 * n) len <<= 1;
203         memset(ln_t + n, 0, (len - n) * sizeof(LL));
204         memset(ln_t2 + n, 0, (len - n) * sizeof(LL));
205         NTT(ln_t, len, 1); NTT(ln_t2, len, 1);
206         for(int i = 0; i < len; i++) b[i] = ln_t[i] * ln_t2[i] % MO;
207         NTT(b, len, -1);
208         memset(b + n, 0, (len - n) * sizeof(LL));
209         ter(b, b, n);
210         return;
211     }
212 
213     void Exp(const LL *a, LL *b, int n) {
214         if(n == 1) {
215             b[0] = 1;
216             b[1] = 0;
217             return;
218         }
219         Exp(a, b, n >> 1);
220         /// ans = b * (1 + a - ln b)
221         getLn(b, exp_t, n);
222         for(int i = 0; i < n; i++) A[i] = (a[i] - exp_t[i]) % MO;
223         A[0] = (A[0] + 1) % MO;
224         memset(A + n, 0, n * sizeof(LL));
225         memcpy(B, b, n * sizeof(LL)); memset(B + n, 0, n * sizeof(LL));
226         NTT(A, n << 1, 1); NTT(B, n << 1, 1);
227         for(int i = 0; i < (n << 1); i++) b[i] = A[i] * B[i] % MO;
228         NTT(b, n << 1, -1);
229         memset(b + n, 0, n * sizeof(LL));
230         return;
231     }
232 
233     inline void getExp(const LL *a, LL *b, int n) {
234         int len = 1;
235         while(len < n) len <<= 1;
236         Exp(a, b, len);
237         memset(b + n, 0, (len - n) * sizeof(LL));
238         return;
239     }
240 
241     inline void solve() {
242         if(Y == 1) {
243             LL t = qpow(n, n - 2);
244             printf("%lld
", t * t % MO);
245             return;
246         }
247 
248         LL z = (qpow(Y, MO - 2) - 1) % MO;
249         LL r = 1ll * n * n % MO * qpow(z, MO - 2) % MO;
250 
251         pw[0] = 1;
252         for(int i = 1; i <= n; i++) {
253             pw[i] = pw[i - 1] * i % MO;
254             a[i] = qpow(i, i) * r % MO * qpow(pw[i], MO - 2) % MO;
255         }
256         getExp(a, b, n + 1);
257         LL ans = b[n] * pw[n] % MO;
258         ans = ans * qpow(Y, n) % MO * qpow(z, n) % MO * qpow(n, MO - 5) % MO;
259         printf("%lld
", (ans + MO) % MO);
260         return;
261     }
262 }
263 
264 int main() {
265     
266     int f;
267     scanf("%d%lld%d", &n, &Y, &f);
268     if(f == 0) {
269         t0::solve();
270         return 0;
271     }
272     if(f == 1) {
273         t1::solve();
274         return 0;
275     }
276     t2::solve();
277     return 0;
278 }
AC代码

以蒟蒻视角写了题解,以后还要继续努力!

感谢:

原文地址:https://www.cnblogs.com/huyufeifei/p/10453393.html