[loj3340]命运

容斥,强制若干条链不重要,即有$2^{n-1-s}$种(其中$s$为这些链的并所覆盖的边数),暴力将选中的链打标记,时间复杂度$o(m^{2}2^{m}+nlog_{2}n)$(预处理出这$2m$个点的虚树),期望得分32(实际得分40)

考虑在计算$s$时可以差分来统计,时间复杂度可以做到$o(m2^{m}+nlog_{2}n)$,期望得分40

考虑另一种优化方法:按照dfs的顺序去枚举,每一次枚举$v=k$的所有链的状态,并维护当前子树内的被选择的深度最小的$u$(如果$u$在$k$子树中令$u=k$),这样的时间复杂度也是$o(m2^{m}+nlog_{2}n)$

观察到当处理完$k$子树后,影响答案的只有:1.深度最小的$u$(的深度);2.所覆盖的边数$s$(仅考虑子树内部),那么用$f[k][u][s]$表示对应的方案数(方案有“正负”),时间复杂度$o(n^{4})$

考虑优化,不妨令$dp[k][u]=sum_{i=0}^{n-1}f[k][u][i]cdot 2^{sz[k]-1-i}$,后者具有可乘性,因此可以转移(注意:转移过程中$sz[k]-1$的意义为考虑过的边数量),时间复杂度$o(n^{2})$

记$S[k][u]=sum_{i=u}^{dep_{k}}dp[k][i]$,简单化简,可以发现新的转移式为$S[k][u]=prod_{son}(S[son][u]+S[son][dep_{son}])$

考虑用线段树来维护这个dp数组,那么要支持:1.区间加;2.对应位置相乘

维护加法标记和乘法标记,然后线段树合并即可,时间复杂度$o(nlog_{2}n)$

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 #define N 500005
 4 #define mod 998244353
 5 #define mid (l+r>>1)
 6 #define pii pair<int,int>
 7 #define fi first
 8 #define se second
 9 struct ji{
10     int nex,to;
11 }edge[N<<1];
12 vector<int>v[N];
13 int V,E,n,m,x,y,head[N],s[N],r[N],ls[N*40],rs[N*40];
14 pii tag[N*40];
15 void add(int x,int y){
16     edge[E].nex=head[x];
17     edge[E].to=y;
18     head[x]=E++;
19 }
20 void upd(int k,pii x){
21     tag[k].fi=1LL*tag[k].fi*x.fi%mod;
22     tag[k].se=(1LL*tag[k].se*x.fi+x.se)%mod;
23 }
24 void down(int k){
25     if (ls[k])upd(ls[k],tag[k]);
26     if (rs[k])upd(rs[k],tag[k]);
27     tag[k]=make_pair(1,0);
28 }
29 void update(int &k,int l,int r,int x,int y,int z){
30     if ((l>y)||(x>r))return; 
31     if (!k){
32         k=++V;
33         tag[k]=make_pair(1,0);
34     }
35     if ((x<=l)&&(r<=y)){
36         tag[k].fi=(tag[k].fi+z)%mod;
37         return;
38     }
39     down(k);
40     update(ls[k],l,mid,x,y,z);
41     update(rs[k],mid+1,r,x,y,z);
42 }
43 int query(int k,int l,int r,int x){
44     if (l==r)return tag[k].se;
45     down(k);
46     if (x<=mid)return query(ls[k],l,mid,x);
47     return query(rs[k],mid+1,r,x);
48 }
49 int merge(int k1,int k2){
50     if ((!k1)||(!k2))return k1+k2;
51     if ((!ls[k1])&&(!rs[k1]))swap(k1,k2); 
52     if ((!ls[k2])&&(!rs[k2])){
53         upd(k1,make_pair(tag[k2].se,0));
54         return k1;
55     }
56     down(k1);
57     down(k2);
58     ls[k1]=merge(ls[k1],ls[k2]);
59     rs[k1]=merge(rs[k1],rs[k2]);
60     return k1;
61 }
62 void dfs(int k,int fa,int sh){
63     s[k]=sh;
64     if (!v[k].size())update(r[k],0,n,0,s[k],1);
65     else{
66         for(int i=1;i<v[k].size();i++)
67             if (s[v[k][i]]>s[v[k][0]])v[k][0]=v[k][i];
68         update(r[k],0,n,s[v[k][0]],s[k],mod-1);
69     }
70     for(int i=head[k];i!=-1;i=edge[i].nex)
71         if (edge[i].to!=fa){
72             dfs(edge[i].to,k,sh+1);
73             r[k]=merge(r[k],r[edge[i].to]);
74         }
75     if (k>1)update(r[k],0,n,0,s[k],query(r[k],0,n,s[k]));
76 }
77 int main(){
78     freopen("destiny.in","r",stdin);
79     freopen("destiny.out","w",stdout);
80     scanf("%d",&n);
81     memset(head,-1,sizeof(head));
82     for(int i=1;i<n;i++){
83         scanf("%d%d",&x,&y);
84         add(x,y);
85         add(y,x);
86     }
87     scanf("%d",&m);
88     for(int i=1;i<=m;i++){
89         scanf("%d%d",&x,&y);
90         v[y].push_back(x);
91     }
92     dfs(1,0,0);
93     printf("%d",query(r[1],0,n,0));
94     return 0;
95 }
View Code
原文地址:https://www.cnblogs.com/PYWBKTDA/p/13693651.html