[gym102220I]Temperature Survey

(为了方便,以下记$a_{0}=0,a_{n+1}=n$​​,并将$n$​​加上1)

构造一个$n$行的网格图,从上到下第$i$行有$a_{i}$个格子,格子左对齐

记第$i$行第$j$个格子为$(i,j)$​,格子集合${(i,j)mid i_{1}le ile i_{2}$且$j_{1}le jle j_{2}}$为$([i_{1},i_{2}],[j_{1},j_{2}])$

此时,考虑一条从$(1,1)$​​到$(n,a_{n})$​​的路径(只能向下或向右),令$b_{i}$​​为路径中从第$i$​​行到第$i+1$​​行时的格子编号,不难发现这样的路径与合法的$b_{i}$​​一一对应,那么不妨统计路径数

关于路径数,显然可以dp解决,即令$f_{i,j}$表示从$(1,1)$到$(i,j)$的路经数,则$f_{i,j}=f_{i,j-1}+f_{i-1,j}$

使用分治优化dp转移,当分治到区间$[l,r]$​​​时,需要根据$f_{[l,r],a_{l-1}}$​​​的值求出$f_{r,(a_{l-1},a_{r}]}$​​​​​的值

具体的,分治过程如下——

1.令$mid=lfloorfrac{l+r}{2} floor$​​​​​,递归左区间$[l,mid]$​​​​​,根据$f_{[l,mid],a_{l-1}}$​​​​​的值求出$f_{mid,(a_{l-1},a_{mid}]}$​​​​​的值

2.根据$f_{(mid,r],a_{l-1}}$​​​​​​​和$f_{mid,(a_{l-1},a_{mid}]}$​​​​​​​的值,快速求出$f_{(mid,r],a_{mid}}$​​​​​​​​和$f_{r,(a_{l-1},a_{mid}]}$​​​​​的值

3.递归右区间$(mid,r]$​,根据$f_{(mid,r],a_{mid}}$​​的值求出$f_{r,(a_{mid},a_{r}])}$​的值

(其中,第二步显然可以写成卷积的形式,直接ntt即可)

边界条件:若$l>r$直接退出,若$l=r$令$f_{l,(a_{l-1},a_{r}]}=f_{l,a_{l-1}}$并退出

另外,有一些细节问题:

1.为了避免位置不合法,需要保证$a_{l-1}<a_{l}$,若不满足则不断增加$l$即可

2.关于$f$的存储,只需要用两个数组,分别存储当前每一行/列递归到最右边的一列/行的$f$值即可

总复杂度为$o(nlog^{2}n)$​​,可以通过​

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define N 200005
  4 #define mod 998244353
  5 #define ll long long
  6 #define vi vector<int>
  7 vi v,vl,vu;
  8 int t,n,fac[N<<1],inv[N<<1],rev[N<<3],a[N],f[N],ans[N];
  9 int c(int n,int m){
 10     return (ll)fac[n]*inv[m]%mod*inv[n-m]%mod;
 11 }
 12 int qpow(int n,int m){
 13     int s=n,ans=1;
 14     while (m){
 15         if (m&1)ans=(ll)ans*s%mod;
 16         s=(ll)s*s%mod;
 17         m>>=1;
 18     }
 19     return ans;
 20 }
 21 void ntt(vi &a,int n,int p){
 22     for(int i=0;i<(1<<n);i++)
 23         if (i<rev[i])swap(a[i],a[rev[i]]);
 24     for(int i=2;i<=(1<<n);i<<=1){
 25         int s=qpow(3,(mod-1)/i);
 26         if (p)s=qpow(s,mod-2);
 27         for(int j=0;j<(1<<n);j+=i)
 28             for(int k=0,ss=1;k<(i>>1);k++,ss=(ll)ss*s%mod){
 29                 int x=a[j+k],y=(ll)a[j+k+(i>>1)]*ss%mod;
 30                 a[j+k]=(x+y)%mod;
 31                 a[j+k+(i>>1)]=(x+mod-y)%mod;
 32             }
 33     }
 34     if (p){
 35         int s=qpow((1<<n),mod-2);
 36         for(int i=0;i<(1<<n);i++)a[i]=(ll)a[i]*s%mod;
 37     }
 38 }
 39 vi mul(vi a,vi b){
 40     int n=0,m=a.size()+b.size()-1;
 41     while ((1<<n)<m)n++;
 42     for(int i=0;i<(1<<n);i++)rev[i]=(rev[i>>1]>>1)+((i&1)*(1<<n)/2);
 43     while (a.size()<(1<<n))a.push_back(0);
 44     while (b.size()<(1<<n))b.push_back(0);
 45     ntt(a,n,0);
 46     ntt(b,n,0);
 47     for(int i=0;i<(1<<n);i++)a[i]=(ll)a[i]*b[i]%mod;
 48     ntt(a,n,1);
 49     return a;
 50 }
 51 void calc(int l,int r){
 52     while ((l<=r)&&(a[l-1]==a[l]))l++;
 53     if (l>r)return;
 54     if (l==r){
 55         for(int i=a[l-1]+1;i<=a[r];i++)ans[i]=f[l];
 56         return;
 57     }
 58     int mid=(l+r>>1);
 59     calc(l,mid);
 60     vl.clear(),vu.clear();
 61     for(int i=mid+1;i<=r;i++)vl.push_back(f[i]);
 62     for(int i=a[l-1]+1;i<=a[mid];i++)vu.push_back(ans[i]);
 63     
 64     v.clear();
 65     for(int i=0;i<r-mid;i++)v.push_back(c(i+a[mid]-a[l-1]-1,i));
 66     v=mul(v,vl);
 67     for(int i=0;i<r-mid;i++)f[i+mid+1]=v[i];
 68     
 69     v.clear();
 70     for(int i=0;i<r-mid+a[mid]-a[l-1]-1;i++)v.push_back(fac[i]);
 71     for(int i=0;i<r-mid;i++)vl[i]=(ll)vl[i]*inv[r-mid-1-i]%mod;
 72     v=mul(v,vl);
 73     for(int i=0;i<a[mid]-a[l-1];i++)ans[i+a[l-1]+1]=(ll)v[i+r-mid-1]*inv[i]%mod;
 74     
 75     v.clear();
 76     for(int i=0;i<a[mid]-a[l-1];i++)v.push_back(c(i+r-mid-1,i));
 77     v=mul(v,vu);
 78     for(int i=0;i<a[mid]-a[l-1];i++)ans[i+a[l-1]+1]=(ans[i+a[l-1]+1]+v[i])%mod;
 79     
 80     v.clear();
 81     for(int i=0;i<r-mid+a[mid]-a[l-1]-1;i++)v.push_back(fac[i]);
 82     for(int i=0;i<a[mid]-a[l-1];i++)vu[i]=(ll)vu[i]*inv[a[mid]-a[l-1]-1-i]%mod;
 83     v=mul(v,vu);
 84     for(int i=0;i<r-mid;i++)f[i+mid+1]=(f[i+mid+1]+(ll)v[i+a[mid]-a[l-1]-1]*inv[i])%mod;
 85     calc(mid+1,r);
 86 }
 87 int main(){
 88     fac[0]=inv[0]=inv[1]=1;
 89     for(int i=1;i<(N<<1);i++)fac[i]=(ll)fac[i-1]*i%mod;
 90     for(int i=2;i<(N<<1);i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
 91     for(int i=1;i<(N<<1);i++)inv[i]=(ll)inv[i-1]*inv[i]%mod;
 92     scanf("%d",&t);
 93     while (t--){
 94         scanf("%d",&n);
 95         for(int i=1;i<=n;i++)scanf("%d",&a[i]);
 96         n++,a[n]=n;
 97         for(int i=1;i<=n;i++)f[i]=ans[i]=0;
 98         f[1]=1;
 99         calc(1,n);
100         printf("%d
",ans[n]);
101     } 
102     return 0;
103 } 
View Code
原文地址:https://www.cnblogs.com/PYWBKTDA/p/15057747.html