[luogu5577]算力训练

(以下以$B$为进制,$m$为幂次,$n=B^{m}$)

定义$oplus$为$k$进制下不进位加法,$otimes$为$oplus$卷积

令$f_{i,j}$表示前$i$个数的$oplus$之和为$j$的子序列数,再令$g_{i,j}=[j=0]+[j=a_{i}]$($a_{i}$为给定序列),则$f_{i}=f_{i-1}otimes g_{i}$

类似uoj272,但以该题复杂度计算时间复杂度显然是不对的

根据$g_{i,j}$的式子,不难发现将其做了DFT后的结果显然恰好就是矩阵$A$的第0行加第$a_{i}$行

直接考虑最终将每一个$g_{i}$的DFT对应位置相乘后第$j$个位置的值,即$ans_{j}=prod_{i=1}^{n}(A_{0,j}+A_{a_{i},j})$

不难发现$A_{i,j}=omega^{k}$(指存在$k$,其中$0le k<B$),我们如果能知道$A_{a_{i},j}$中每一个$k$出现了多少次,再使用快速幂来计算,就可以做到$o(Blog_{2}n)$的复杂度了

更具体的,用$f_{i,j}$表示有多少个$k$满足$A_{a_{k},i}=omega^{j}$,答案即$ans_{j}=prod_{i=0}^{B-1}(1+omega^{i})^{f_{j,i}}$($A_{0,j}=1$)

如何求出$f_{i,j}$,其并不容易递推,考虑这样一个构造:对于每一个$i$,求出$f_{i}$这个长度为$B$的序列DFT的结果,再用IDFT即求出$f_{i,j}$

考虑这个DFT结果的第$k$个数,即$sum_{l=0}^{B-1}f_{i,l}A_{l,k}=sum_{l=0}^{B-1}f_{i,l}(omega^{l})^{k}=sum_{l=0}^{n-1}(A_{a_{l},i})^{k}$

再构造一个$C_{i}=sum_{j=0}^{n-1}[a_{j}=i]$,那么对$C$做DFT后的第$i$项即为$sum_{j=0}^{n-1}C_{j}A_{j,i}=sum_{j=0}^{n-1}A_{a_{j},i}$

其实这两个式子很接近,只需要让每一个$A_{i,j}$都变为其$k$次幂即可

注意到我们能快速计算DFT依赖于第二个性质($A_{i,j}=A_{lfloorfrac{i}{B} floor,lfloorfrac{j}{B} floor}A_{i mod B,j mod B}$),而在这个性质下,让每一个$A_{i,j}$都变为其$k$次幂等价于构造$A$左上角的$B imes B$的部分为$A_{i,j}=omega^{ijk}$

具体来说,分为以下四个步骤:

1.对$C$做$B$次DFT,每一次DFT的$A$矩阵不同,第$k(0le k<B)$次DFT的$A_{i,j}=omega^{ijk}(0le i,j<B)$,这里的时间复杂度是$o(mB^{m+4})$(由于两数相乘复杂度也为$o(B^{2})$,一次DFT复杂度为$o(mB^{m+3})$)

2.对于第一步中第$i$次DFT结果的第$j$项,恰好就是$f_{j}$(这是一个长为$B$的数列)做DFT后的第$i$项,换言之我们得到了每一个$f_{j}$做DFT后的结果,做$n$次$B^{4}$的IDFT即可,复杂度为$o(B^{m+4})$

3.得到$f_{i,j}$后,直接根据$ans_{j}=prod_{i=0}^{B-1}(1+omega^{i})^{f_{j,i}}$计算出$ans_{j}$,通过快速幂来优化,那么求一个$ans_{j}$的时间复杂度为$o(B^{3}log_{2}n)$,总复杂度即$o(B^{m+3}log_{2}n)$

4.求出$ans_{j}$再做一次IDFT即为答案,时间复杂度为$o(mB^{m+3})$

最终总复杂度为$o((m+log_{2}n)B^{m+4})$,可以通过

另外关于数值的表示,在平常递归时先使用$sum_{i=0}^{B-1}a_{i}omega^{i}$来表示,根据$omega^{B}=1$可以对其封闭运算,当我们可以证明某一个数为实数且需要得到该值时,通过如下方式降幂,然后$omega^{0}$系数即为答案

降幂的需要对$B$分类:

1.若$B=5$,将$omega^{4}$利用$sum_{i=0}^{4}omega^{i}=frac{1-omega^{5}}{1-omega}=0$来降幂

2.若$B=6$,根据$omega^{3}=-1$来降幂,首先得到$omega^{i}=-omega^{i-frac{B}{2}}$来将$i$次项($frac{B}{2}le i<B$)降幂,再利用$sum_{i=0}^{2}(-omega)^{i}=frac{1-(-omega)^{3}}{1+omega}=0$来降$omega^{2}$

关于这个降幂的正确性(也就是之后高次项不能将虚数部分抵消)不会证,但可以发现其等价于不能再次进行降幂,之后(观察)发现找不到继续降的方式,即合法

(注意输入是$B$进制)

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define N 100005
  4 #define M 7
  5 #define maxB 6
  6 #define mod 998244353
  7 int n,m,x,B,base[N][M];
  8 struct Complex{
  9     int a[maxB];
 10     Complex(){
 11         memset(a,0,sizeof(a));
 12     }
 13     Complex(int x){
 14         memset(a,0,sizeof(a));
 15         a[0]=x;
 16     }
 17     Complex(int x,int y){
 18         memset(a,0,sizeof(a));
 19         a[0]=x,a[1]=y;
 20     }
 21     Complex operator + (const Complex &k)const{
 22         Complex o;
 23         for(int i=0;i<B;i++)o.a[i]=(a[i]+k.a[i])%mod;
 24         return o;
 25     }
 26     Complex operator * (const Complex &k)const{
 27         Complex o;
 28         for(int i=0;i<B;i++)
 29             for(int j=0;j<B;j++)o.a[(i+j)%B]=(o.a[(i+j)%B]+1LL*a[i]*k.a[j])%mod;
 30         return o;
 31     }
 32     int get(){
 33         if (B==5)return (a[0]-a[4]+mod)%mod;
 34         return ((a[0]-a[3]+mod)%mod-(a[2]-a[5]+mod)%mod+mod)%mod;
 35     }
 36 }inv,A[maxB][maxB],AA[maxB][maxB],invA[maxB][maxB],a[N],b[maxB][N],f[N][maxB];
 37 int read(){
 38     int x=0;
 39     char c=getchar();
 40     while ((c<'0')||(c>'9'))c=getchar();
 41     while ((c>='0')&&(c<='9')){
 42         x=x*B+c-'0';
 43         c=getchar();
 44     }
 45     return x;
 46 }
 47 Complex pow(Complex n,int m){
 48     Complex s=n,ans=Complex(1);
 49     while (m){
 50         if (m&1)ans=ans*s;
 51         s=s*s;
 52         m>>=1;
 53     }
 54     return ans;
 55 }
 56 void DFT(Complex *a){
 57     Complex aa[B];
 58     for(int i=0,s=1;i<m;i++,s*=B)
 59         for(int j=0;j<n;j++)
 60             if (!base[j][i]){
 61                 for(int k=0;k<B;k++)aa[k]=Complex();
 62                 for(int k=0;k<B;k++)
 63                     for(int l=0;l<B;l++)aa[k]=aa[k]+a[j+l*s]*A[l][k];
 64                 for(int k=0;k<B;k++)a[j+k*s]=aa[k];
 65             }
 66 }
 67 void IDFT(Complex *a){
 68     Complex aa[B];
 69     for(int i=0,s=1;i<m;i++,s*=B)
 70         for(int j=0;j<n;j++)
 71             if (!base[j][i]){
 72                 for(int k=0;k<B;k++)aa[k]=Complex();
 73                 for(int k=0;k<B;k++)
 74                     for(int l=0;l<B;l++)aa[k]=aa[k]+a[j+l*s]*invA[l][k];
 75                 for(int k=0;k<B;k++)a[j+k*s]=aa[k];
 76             }
 77 }
 78 int main(){
 79     scanf("%d%d%d",&n,&B,&m);
 80     for(int i=0;i<n;i++){
 81         x=read();
 82         a[x]=a[x]+Complex(1);
 83     }
 84     n=1;
 85     for(int i=0;i<m;i++)n*=B;
 86     for(int i=0;i<n;i++){
 87         base[i][0]=i%B;
 88         for(int j=1;j<m;j++)base[i][j]=base[i/B][j-1];
 89     }
 90     inv=pow(Complex(B),mod-2);
 91     for(int i=0;i<B;i++)
 92         for(int j=0;j<B;j++){
 93             A[i][j]=Complex(1);
 94             AA[i][j]=pow(Complex(0,1),i*j);
 95             invA[i][j]=pow(Complex(0,1),B*B-i*j)*inv;
 96         }
 97     for(int i=0;i<B;i++){
 98         memcpy(b[i],a,sizeof(a));
 99         DFT(b[i]);
100         for(int j=0;j<B;j++)
101             for(int k=0;k<B;k++)A[j][k]=A[j][k]*AA[j][k];
102     }
103     for(int i=0;i<n;i++)
104         for(int j=0;j<B;j++)
105             for(int k=0;k<B;k++)f[i][j]=f[i][j]+b[k][i]*invA[k][j];
106     for(int i=0;i<n;i++){
107         a[i]=Complex(1);
108         for(int j=0;j<B;j++)a[i]=a[i]*pow(pow(Complex(0,1),j)+Complex(1),f[i][j].get());
109     }
110     IDFT(a);
111     for(int i=0;i<n;i++)printf("%d
",a[i].get());
112 } 
View Code
原文地址:https://www.cnblogs.com/PYWBKTDA/p/14498522.html