LOJ 6436 「PKUSC2018」神仙的游戏——思路+卷积

题目:https://loj.ac/problem/6436

看题解才会。

有长为 i 的 border ,就是有长为 n-i 的循环节。

考虑如果 x 位置上是 0 、 y 位置上是 1 ,那么长度是 | x-y | 的约数的循环节都不可行,因为在该循环节中, x 和 y 处在 “应该相等” 的地位。

最后一个部分分是暴力枚举 0 和 1 来预处理出一个 h[ i ] 表示长度是 i 的约数的循环节不可行。然后枚举循环节的长度 i ,再枚举 i 的倍数看看有没有 “不可行” 的。这样是 nlogn 。

考虑用卷积来优化求 h[ ] 。就是想找 “位置差一定” 的一个 0 和一个 1 ;令 ( F(x) = sum [ s[i]=='0' ] x^i ) ,( G(x) = sum [ s[i]=='1' ] x^i ) ,翻转其中一个,做卷积即可。

感觉很卡常。没有 ' ? ' 的那个子任务,自己必须特判(用 kmp 做)才能不超时。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=5e5+5,M=(1<<20)+5,mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

int n,f[M],g[M],len,r[M];char s[N];
int wn[M],wn2[M],nxt[N];
void ntt_init()
{
  for(int R=2;R<=len;R<<=1)
    {
      wn[R]=pw(3,(mod-1)/R);
      wn2[R]=pw(3,(mod-1)-(mod-1)/R);
    }
}
void ntt(int *a,bool fx)
{
  for(int i=0;i<len;i++)
    if(i<r[i])swap(a[i],a[r[i]]);
  for(int R=2;R<=len;R<<=1)
    {
      int Wn=(fx?wn2[R]:wn[R]);
      for(int i=0,m=R>>1;i<len;i+=R)
    for(int j=0,w=1;j<m;j++,w=(ll)w*Wn%mod)
      {
        int x=a[i+j],y=(ll)w*a[i+m+j]%mod;
        a[i+j]=upt(x+y); a[i+m+j]=upt(x-y);
      }
    }
  if(!fx)return; int inv=pw(len,mod-2);
  for(int i=0;i<len;i++)a[i]=(ll)a[i]*inv%mod;
}
void kmp()
{
  for(int i=n;i;i--)s[i]=s[i-1];
  for(int i=2;i<=n;i++)
    {
      int cr=nxt[i-1];
      while(cr&&s[cr+1]!=s[i])cr=nxt[cr];
      if(s[cr+1]==s[i])nxt[i]=cr+1;
      else nxt[i]=0;
    }
  ll ans=(ll)n*n;
  int cr=nxt[n];
  while(cr)
    {
      ans^=(ll)cr*cr; cr=nxt[cr];
    }
  printf("%lld
",ans);
}
int main()
{
  scanf("%s",s); n=strlen(s); bool chk=0;
  for(int i=0;i<n;i++)
    {
      f[i]=(s[i]=='0'); g[n-1-i]=(s[i]=='1');
      if(s[i]=='?')chk=1;
    }
  if(!chk){ kmp();return 0;}
  for(len=1;len<n<<1;len<<=1);
  for(int i=0,j=len>>1;i<len;i++)
    r[i]=(r[i>>1]>>1)+((i&1)?j:0);
  ntt_init();
  ntt(f,0); ntt(g,0);
  for(int i=0;i<len;i++)f[i]=(ll)f[i]*g[i]%mod;
  ntt(f,1);
  for(int i=0;i<n;i++)g[i]=((f[n-1-i]||f[n-1+i])?1:0);
  ll ans=(ll)n*n;
  for(int i=1;i<n;i++)
    {
      int x=n-i; bool fg=0;
      for(int j=x;j<n;j+=x)
    if(g[j]){fg=1;break;}
      if(!fg)ans^=(ll)i*i;
    }
  printf("%lld
",ans);
  return 0;
}
原文地址:https://www.cnblogs.com/Narh/p/10899816.html