51nod 1348 乘积之和

Description

给出由N个正整数组成的数组A,有Q次查询,每个查询包含一个整数K,从数组A中任选K个(K <= N)把他们乘在一起得到一个乘积。求所有不同的方案得到的乘积之和,由于结果巨大,输出Mod 100003的结果即可。例如:1 2 3,从中任选1个共3种方法,{1} {2} {3},和为6。从中任选2个共3种方法,{1 2} {1 3} {2 3},和为2 + 3 + 6 = 11

Solution

答案就是 (Pi_{i=1}^{n} (a_i*x_i+1)) 这个多项式的 (x^k) 项的系数
直接分治 (NTT) 求解即可
类似于线段树合并的方法,把多项式合并

值得注意的是这题的模数不是费马质数,可以用到一个套路:用乘积大于 (P^2*n) 的两个费马质数代替,最后 (CRT) 一下

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=200005,M=100003;
int a[N],f[20][N],P[2]={998244353,1004535809},m,R[N];
inline int qm(int x,int k,int mod){
	int sum=1;
	while(k){
		if(k&1)sum=1ll*x*sum%mod;
		x=1ll*x*x%mod;k>>=1;
	}return sum;
}
inline void NTT(int *A,int o,int mod,int n){
	for(int i=0;i<n;i++)if(i<R[i])swap(A[i],A[R[i]]);
	for(int i=1;i<n;i<<=1){
		int t0=qm(3,(mod-1)/(i<<1),mod),x,y;
		for(int j=0;j<n;j+=(i<<1)){
			int t=1;
			for(int k=0;k<i;k++,t=1ll*t0*t%mod){
				x=A[j+k];y=1ll*t*A[j+k+i]%mod;
				A[j+k]=(x+y)%mod;A[j+k+i]=(x-y+mod)%mod;
			}
		}
	}
	if(o==-1)reverse(A+1,A+n);
}
inline void mul(int *A,int *B,int mod,int n){
	NTT(A,1,mod,n);NTT(B,1,mod,n);
	for(int i=0;i<=n;i++)A[i]=1ll*A[i]*B[i]%mod;
	NTT(A,-1,mod,n);
	int t=qm(n,mod-2,mod);
	for(int i=0;i<=n;i++)A[i]=1ll*A[i]*t%mod;
}
inline ll ksc(ll x,ll k,ll mod){
	if(mod<=P[1])return x%mod*k%mod;
	ll sum=0;
	if(x>=mod)x%=mod;if(k>=mod)k%=mod;
	while(k){
		if(k&1)sum=(sum+x)%mod;
		x=(x+x)%mod;k>>=1;
	}return sum;
}
inline int CRT(int x,int y){
	ll lcm=1ll*P[0]*P[1];
	int inv1=qm(P[1],P[0]-2,P[0]),inv2=qm(P[0],P[1]-2,P[1]);
	return (ksc(1ll*inv1*P[1],x,lcm)+ksc(1ll*inv2*P[0],y,lcm))%lcm%M;
}
inline void solve(int l,int r,int t){
	if(l==r){f[t][0]=1;f[t][1]=a[l]%M;return ;}
	int mid=(l+r)>>1,m=r-l+1,n,L;
	for(n=1,L=0;n<=m;n<<=1)L++;
	int A[2][n+5],B[2][n+5];
	memset(A,0,sizeof(A));memset(B,0,sizeof(B));
	solve(l,mid,t+1);
	for(int i=0;i<=mid-l+1;i++)A[0][i]=A[1][i]=f[t+1][i];
	
	solve(mid+1,r,t+1);
	for(int i=0;i<=r-mid;i++)B[0][i]=B[1][i]=f[t+1][i];

	for(int i=0;i<n;i++)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
	mul(A[0],B[0],P[0],n);mul(A[1],B[1],P[1],n);

	for(int i=0;i<=m;i++)f[t][i]=CRT(A[0][i],A[1][i]);
}
int main(){
	freopen("pp.in","r",stdin);
	freopen("pp.out","w",stdout);
	int n,Q,x;
	scanf("%d%d",&n,&Q);
	for(int i=1;i<=n;i++)scanf("%d",&a[i]);
	solve(1,n,1);
	while(Q--)scanf("%d",&x),printf("%d
",f[1][x]);
	return 0;
}

原文地址:https://www.cnblogs.com/Yuzao/p/8508817.html