[SDOI2017]遗忘的集合

description

Luogu
BZOJ
你原来有一个集合 $S$，集合中的元素都是不超过 $n$ 的正整数，
并且你计算出了对于任意 $x\le n$，把 $x$ 表示成 $S$ 中元素之和（每个元素可以重复使用）的方案数 $f_x$。
现在你遗忘了原来的集合，只知道 $f_x \bmod p$。
求所有可能的集合 $S$ 中字典序最小的解。

data range

$$n\le 2^{16},\qquad 10^6\le p<2^{30},\ p\ \text{为质数}$$

solution

关于和的方案数的问题,自然想到使用生成函数。我才不会说是为了搞生成函数做的

设 $F(x)=\sum_{i=0}^{\infty}f_ix^i$（其中 $f_0=1$），$c_i\in\{0,1\}$ 表示 $i$ 是否属于 $S$，那么有

$$F(x)=\frac{1}{\prod_{i=1}^{\infty}(1-c_ix^i)}$$

对乘积取 $\ln$ 转化为求和的套路十分常见（可参考另一道题）。

$$\ln F(x)=-\sum_{i=1}^{\infty}\ln(1-c_ix^i)$$

由于 $\ln(1-x^i)=-\sum_{j=1}^{\infty}\frac{1}{j}x^{ij}$，而由 $c_i\in\{0,1\}$ 的定义有 $c_i^k=c_i$（$k\ge 1$），于是

$$
\begin{aligned}
\ln F(x)&=\sum_{i=1}^{\infty}\sum_{j=1}^{\infty}\frac{1}{j}c_ix^{ij}\\
&\equiv\sum_{i=1}^{n}\sum_{j=1}^{\lfloor n/i\rfloor}\frac{1}{j}c_ix^{ij}\pmod{x^{n+1}}\\
&=\sum_{i=1}^{n}\sum_{j\mid i}\frac{1}{i/j}\,c_jx^{i}\\
&=\sum_{i=1}^{n}\frac{x^i}{i}\sum_{j\mid i}c_j\,j
\end{aligned}
$$

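记 $g_i=i\cdot[x^i]\ln F(x)$，则 $g_i=\sum_{j\mid i}c_j\,j$。莫比乌斯反演得 $c_i\,i=\sum_{j\mid i}\mu\!\left(\frac{i}{j}\right)g_j$，即可求出 $c_i$（因为 $i\le n<p$，$c_i\,i\not\equiv 0\pmod p$ 当且仅当 $i\in S$）。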

Code

由于 $p$ 是任意质数（一般不是 NTT 模数），实现需要任意模数多项式乘法（MTT），比较难写。

#include<bits/stdc++.h>
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<iomanip>
#include<cstring>
#include<complex>
#include<vector>
#include<cstdio>
#include<string>
#include<bitset>
#include<cassert>
#include<ctime>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<set>
#define FL "ex_unknown"
#define mp make_pair
#define pb push_back
#define fi first
#define se second
// `register` is ill-formed as a storage-class specifier since C++17, so RG
// expands to nothing; existing `RG` uses keep compiling with unchanged meaning.
#define RG
using namespace std;
using ull=unsigned long long;
using VI=vector<int>;
using ll=long long;
using dd=double;
const dd pi=acos(-1.0);   // pi
const dd eps=1e-6;
const int N=1e6+10;       // capacity of the global and static work arrays
/// Reads the next signed decimal integer from stdin.
/// Any characters other than '-' and digits before the number are skipped.
/// @return the parsed value (same type as `ll`); 0 if input ends before a digit.
inline long long read(){
  long long data=0,w=1;
  int ch=getchar();  // int, not char, so EOF stays distinguishable from real bytes
  // Stop at EOF as well; otherwise truncated input would spin here forever.
  while(ch!=EOF&&ch!='-'&&(ch<'0'||ch>'9'))ch=getchar();
  if(ch=='-')w=-1,ch=getchar();
  while(ch>='0'&&ch<='9')data=data*10+(ch-'0'),ch=getchar();
  return data*w;
}
// Local-testing helper (main does not call it): redirect stdin to "a.in"
// and stdout to "a.out".
inline void file(){
  const char *in_path="a.in";
  const char *out_path="a.out";
  freopen(in_path,"r",stdin);
  freopen(out_path,"w",stdout);
}

// Solver state:
//   n    : number of coefficients handled (main sets it to input n + 1)
//   mod  : the modulus p
//   f, g : polynomial coefficient buffers
//   inv  : inv[i] = i^{-1} mod `mod` for 1 <= i < N (filled by init())
//   len  : power of two >= n used as the series length
//   pri, mu, vis : linear-sieve primes, Mobius function and "composite" flags
int n,mod,f[N],g[N],inv[N],len;
int pri[N],mu[N];bool vis[N];
/// Minimal complex number used by the FFT (r = real part, i = imaginary part).
/// w[k][0] / w[k][1] hold a root of unity and its conjugate (filled by init()).
struct point{dd r,i;}w[N][2];
// Brace initialization is used instead of a `(point){...}` compound literal,
// which is a GNU extension rather than standard C++.
point operator +(point a,point b){return point{a.r+b.r,a.i+b.i};}
point operator -(point a,point b){return point{a.r-b.r,a.i-b.i};}
point operator *(point a,point b){
  return point{a.r*b.r-a.i*b.i,a.r*b.i+a.i*b.r};
}
// Debug helper: writes a[0..n-1] to stdout, each value followed by a space,
// then a newline.
inline void print(int *a,int n){
  int k=0;
  while(k<n)printf("%d ",a[k++]);
  puts("");
}
// Resolution of the root-of-unity table w; FFT lengths must be powers of two <= S.
const int S=1<<19;

// a = (a + b) mod `mod`, for a and b already in [0, mod).
inline void upd(int &a,int b){
  const int sum=a+b;
  a=sum>=mod?sum-mod:sum;
}
// a = (a - b) mod `mod`, for a and b already in [0, mod).
inline void dec(int &a,int b){
  const int diff=a-b;
  a=diff<0?diff+mod:diff;
}
// Binary exponentiation: returns a^b mod `mod` for b >= 0.
inline int poww(int a,int b){
  int result=1;
  while(b){
    if(b&1)result=1ll*result*a%mod;
    a=1ll*a*a%mod;
    b>>=1;
  }
  return result;
}
inline void init(){
  vis[1]=mu[1]=inv[0]=inv[1]=1;
  for(int i=0;i<=S;i++){
    w[i][0]=w[i][1]=(point){cos(2*pi*i/S),sin(2*pi*i/S)};
    w[i][1].i=-w[i][1].i;
  }
  for(int i=2;i<N;i++)inv[i]=mod-1ll*(mod/i)*inv[mod%i]%mod;
  for(int i=2;i<N;i++){
    if(!vis[i])pri[++pri[0]]=i,mu[i]=-1;
    for(int j=1;j<=pri[0]&&1ll*i*pri[j]<N;j++){
      vis[i*pri[j]]=1;mu[i*pri[j]]=mod-mu[i];
      if(i%pri[j]==0){mu[i*pri[j]]=0;break;}
    }
  }
}

/// In-place iterative radix-2 FFT on a[0..m-1], where m is n rounded up to a
/// power of two (entries up to m-1 are read, so padding must be prepared by the caller).
/// opt == -1 selects the inverse transform (conjugate twiddles), after which only
/// the real parts are divided by m. Twiddles come from w, so m must divide S.
inline void FFT(point *a,int n,int opt){
  static int rev[N];
  int bits=0;
  while((1<<bits)<n)++bits;
  n=1<<bits;
  // rev[i] = i with its low `bits` bits reversed (rev[0] is always 0).
  for(int i=0;i<n;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bits-1));
  for(int i=0;i<n;i++){
    if(i<rev[i])swap(a[i],a[rev[i]]);
  }
  const bool inverse=(opt==-1);
  for(int span=2;span<=n;span<<=1){
    const int half=span>>1;
    for(int base=0;base<n;base+=span){
      point *lo=a+base;
      point *hi=a+base+half;
      for(int k=0;k<half;k++){
        const point t=hi[k]*w[1ll*k*S/span][inverse];
        hi[k]=lo[k]-t;
        lo[k]=lo[k]+t;
      }
    }
  }
  if(inverse){
    for(int i=0;i<n;i++)a[i].r/=n;
  }
}
/// Multiplies two polynomials modulo an arbitrary `mod` (< 2^30) with
/// floating-point FFTs, splitting each coefficient as x = (x/q)*q + x%q with
/// q = floor(sqrt(mod)) + 1 (4 forward + 3 inverse transforms).
/// @param f first operand f[0..n-1], coefficients in [0, mod)
/// @param g second operand g[0..n-1], coefficients in [0, mod)
/// @param h receives h[0..2n-1] = (f*g) mod `mod`; may alias f or g, because
///          the inputs are copied out before h is written
/// @param n operand length (need not be a power of two)
inline void conv(int *f,int *g,int *h,int n){
  static point a[N],b[N],c[N],d[N],s1[N],s2[N],s3[N];
  const int q=sqrt(mod)+1;
  // FFT rounds its length up to a power of two; pad to exactly that size so
  // no stale data from a previous call enters the transform.
  int m=1;
  while(m<(n<<1))m<<=1;
  for(int i=0;i<m;i++)a[i]=b[i]=c[i]=d[i]=point{0,0};
  for(int i=0;i<n;i++){
    a[i].r=f[i]/q;b[i].r=f[i]%q;c[i].r=g[i]/q;d[i].r=g[i]%q;
  }
  FFT(a,m,1);FFT(b,m,1);FFT(c,m,1);FFT(d,m,1);
  for(int i=0;i<m;i++){
    s1[i]=a[i]*c[i];                 // high*high -> weight q^2
    s2[i]=(a[i]*d[i])+(b[i]*c[i]);   // cross terms -> weight q
    s3[i]=b[i]*d[i];                 // low*low   -> weight 1
  }
  FFT(s1,m,-1);FFT(s2,m,-1);FFT(s3,m,-1);
  for(int i=0;i<(n<<1);i++){
    h[i]=0;
    // Round the (non-negative) exact integer results back from doubles.
    upd(h[i],((ll)(s1[i].r+0.5))*q%mod*q%mod);
    upd(h[i],((ll)(s2[i].r+0.5))*q%mod);
    upd(h[i],((ll)(s3[i].r+0.5))%mod);
  }
}

void getinv(int *f,int *g,int n){
  static int a[N],b[N];
  if(n==1){g[0]=poww(f[0],mod-2);return;}getinv(f,g,n>>1);
  for(int i=0;i<(n<<1);i++)a[i]=b[i]=0;
  for(int i=0;i<n;i++)a[i]=f[i],b[i]=g[i];
  conv(a,b,a,n);
  for(int i=0;i<n;i++)if(a[i])a[i]=mod-a[i];upd(a[0],2);
  for(int i=n;i<(n<<1);i++)a[i]=0;
  for(int i=0;i<n;i++)b[i]=g[i];
  conv(a,b,a,n);
  for(int i=0;i<n;i++)g[i]=a[i];
}
// Formal derivative: g[k-1] = k * f[k] mod `mod` for 1 <= k < n.
// g[n-1] is not written.
inline void getdao(int *f,int *g,int n){
  for(int k=1;k<n;++k)g[k-1]=1ll*k*f[k]%mod;
}
// Formal integral: g[0] = 0 and g[k] = f[k-1] * inv[k] mod `mod` for 1 <= k < n.
// g must not alias f (g[0] is overwritten before f[0] is read).
inline void getjifen(int *f,int *g,int n){
  g[0]=0;
  int k=1;
  while(k<n){
    g[k]=1ll*f[k-1]*inv[k]%mod;
    ++k;
  }
}

inline void getln(int *f,int *g,int n){
  static int a[N],b[N];
  for(int i=0;i<(n<<1);i++)a[i]=b[i]=0;
  getdao(f,a,n);getinv(f,b,n);
  conv(a,b,a,n);getjifen(a,g,n);
}
void getexp(int *f,int *g,int n){
  static int a[N],b[N];
  if(n==1){g[0]=1;return;}getexp(f,g,n>>1);
  for(int i=0;i<(n<<1);i++)a[i]=b[i]=0;
  for(int i=0;i<n;i++)a[i]=g[i];
  getln(g,b,n);for(int i=0;i<n;i++)b[i]=(f[i]-b[i]+mod)%mod;upd(b[0],1);
  conv(a,b,a,n);for(int i=0;i<n;i++)g[i]=a[i];
}

int cal[N],top;
int main()
{
  n=read()+1;mod=read();for(len=1;len<n;len<<=1);init();
  f[0]=1;for(int i=1;i<n;i++)f[i]=read();
  getln(f,g,len);
  memset(f,0,sizeof(f));
  for(int i=0;i<n;i++)f[i]=1ll*g[i]*i%mod;
  memset(g,0,sizeof(g));
  for(int i=1;i<n;i++)
    for(int j=1;j*i<n;j++)
      (g[i*j]+=1ll*f[i]*mu[j]%mod)%=mod;
  for(int i=1;i<n;i++)if(g[i])cal[++top]=i;
  printf("%d
",top);
  for(int i=1;i<=top;i++)printf("%d ",cal[i]);puts("");
  return 0;
}

原文地址:https://www.cnblogs.com/cjfdf/p/10146199.html