【XSY3241】暴风士兵

【XSY3241】暴风士兵

他是暴风士兵,我是伞兵。

我们考虑令(C(x)=sum_{i=0}^{exp}(exp-i)x^i)(P(x))为扣(i)滴血的概率(P_i)的生成函数。

那么不难发现,对于一个时间(t),答案即为:

[ans_t=sum_{i=0}^{exp}C_iP_i ]

然后我们不难发现,每经过一个时间点(t)(P(x) imes=(P_ix+(1-P_i)))

但这样似乎还是(n^2)的,做不了呀。

我们考虑随便设一个断点(k),然后让(A(x)=prod_{i=1}^k(P_ix+(1-P_i))),(B(x)=prod_{i=k+1}^t(P_ix+(1-P_i))),于是就有:

[egin{aligned} ans_t &= sum_{i=0}^{exp} C_i [x^i]A(x)B(x)\ &= sum_{i=0}^{exp} C_i sum_{j=0}^i A_jB_{i-j}\ &= sum_{i=0}^{exp} B_i sum_{j=0}^{exp-i} C_{j+i}A_j\ &= sum_{i=0}^{exp} C'i[x^i]B(x) end{aligned} ]

于是我们对于每一个(t),令(k=t-1),可以用分治(NTT)加上减法卷积算出([1,t-1])(C'),然后点乘上(i)处的(B)就可以了。此处的(B)只有两项,(O(1))即可。

(于是一个强制在线问题被分治搞掉了,真的高。)

(还有更草的,作为蒟蒻的我之前居然没有试过减法NTT的分治www然后卡了半天(

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int dec(int x,int y){return x<y?x-y+mod:x-y;}
inline int mul(int x,int y){return 1ll*x*y%mod;}
inline int qpow(int n,int k){
    int ret=1;
    while(k){
        if(k&1)ret=mul(ret,n);
        n=mul(n,n);
        k>>=1;
    }
    return ret;
}
int G[2][270010][20];
void init(int lim){
    for(int mid=1,dep=0;mid<lim;mid<<=1,dep++){
        int len=mid<<1;
        int gn=qpow(3,(mod-1)/len);
        int ign=qpow(gn,mod-2);
        int g=1,ig=1;
        for(int j=0;j<mid;++j,g=mul(g,gn),ig=mul(ig,ign)){
            G[1][j][dep]=g,G[0][j][dep]=ig;
        }
    }
}
int rev[270010];
void NTT(int *A,int lim,int opt){
    for(int i=0;i<lim;++i){
        rev[i]=(rev[i>>1]>>1)|((i&1)*(lim>>1));
        if(i<rev[i])swap(A[i],A[rev[i]]);
    }
    for(int mid=1,dep=0;mid<lim;mid<<=1,dep++){
        int len=mid<<1;
        for(int i=0;i<lim;i+=len){
            for(int j=0;j<mid;++j){
                int x=A[i+j],y=mul(G[opt][j][dep],A[i+j+mid]);
                A[i+j]=add(x,y);
                A[i+j+mid]=dec(x,y);
            }
        }
    }
    if(!opt){
        int div=qpow(lim,mod-2);
        for(int i=0;i<lim;++i)A[i]=mul(A[i],div);
    }
}
int lst;
vector<int> c[400010];
vector<int> p[400010];
#define ls (o<<1)
#define rs (o<<1|1)
void solve(int o,int l,int r){
	if(l==r){
		int nowp;
		scanf("%d",&nowp);
		p[o].push_back(dec(1,add(nowp,lst)));p[o].push_back(add(nowp,lst));
//		cout<<p[o][0]<<" "<<p[o][1]<<" "<<c[o][0]<<" "<<c[o][1]<<endl;
		printf("%d
",lst=(add(mul(c[o][0],p[o][0]),mul(c[o][1],p[o][1]))));
		return;
	}
	static int A[270010],B[270010];
	int mid=(l+r)/2;
	int len=(r-l+1),lenl=(mid-l+1),lenr=(r-mid);
	for(int i=0;i<=lenl;++i)c[ls].push_back(c[o][i]);
	solve(ls,l,mid);
	int lim=1;
	while(lim<=len+lenl)lim<<=1;
	for(int i=0;i<=len;++i)A[i]=c[o][i];
	for(int i=0;i<=lenl;++i)B[i]=p[ls][lenl-i];
	NTT(A,lim,1),NTT(B,lim,1);
	for(int i=0;i<lim;++i)A[i]=mul(A[i],B[i]);
	NTT(A,lim,0);
	for(int i=0;i<=lenr;++i)c[rs].push_back(A[i+lenl]);
	for(int i=0;i<lim;++i)A[i]=B[i]=0;
	solve(rs,mid+1,r);
	lim=1;
	while(lim<=lenl+lenr)lim<<=1;
	for(int i=0;i<=lenl;++i)A[i]=p[ls][i];
	for(int i=0;i<=lenr;++i)B[i]=p[rs][i];
	NTT(A,lim,1),NTT(B,lim,1);
	for(int i=0;i<lim;++i)A[i]=mul(A[i],B[i]);
	NTT(A,lim,0);
	for(int i=0;i<=len;++i)p[o].push_back(A[i]);
	for(int i=0;i<lim;++i)A[i]=B[i]=0;
	p[ls].clear(),p[rs].clear();
}
#undef ls
#undef rs
int main(){
	int exp,n;
	scanf("%d%d",&exp,&n);
	lst=exp;
	int lim=1;
	while(lim<=(n<<1))lim<<=1;
	init(lim);
	for(int i=0;i<=exp;++i)c[1].push_back(exp-i);
	for(int i=exp+1;i<=n;++i)c[1].push_back(0);
	solve(1,1,n);
}
原文地址:https://www.cnblogs.com/youddjxd/p/15091306.html