BZOJ 3509: [CodeChef] COUNTARI

Description

给定一个长度为(N)的数组(A[]),求有多少对(i, j, k(1leqslant i<j<k leqslant N))满足(A[k]-A[j]=A[j]-A[i])。

Solution

分块FFT。

每个暴力求需要(n)次FFT。

分块的话,FFT求块与块之间的,块内的暴力求。

复杂度(O(nsqrt {n}logn))

BZOJ上险些T了qwq...

Code

/**************************************************************
    Problem: 3509
    User: BeiYu
    Language: C++
    Result: Accepted
    Time:37220 ms
    Memory:7936 kb
****************************************************************/
 
#include <bits/stdc++.h>
using namespace std;
 
#define debug(a) cout<<#a<<"="<<a<<" "
#define mpr make_pair
#define r first
#define i second
 
typedef pair< double,double > pr;
typedef long long LL;
const int N = 1e5+50;
const int B = 2500;
const double Pi = M_PI;
 
pr operator + (const pr &a,const pr &b) { return mpr(a.r+b.r,a.i+b.i); }
pr operator - (const pr &a,const pr &b) { return mpr(a.r-b.r,a.i-b.i); }
pr operator * (const pr &a,const pr &b) { return mpr(a.r*b.r-a.i*b.i,a.r*b.i+a.i*b.r); }
 
int NN=65536;
 
void Rev(pr a[]) {
    for(int i=0,j=0;i<NN;i++) {
        if(i<j) swap(a[i],a[j]);
        for(int k=NN>>1;(j^=k)<k;k>>=1);
    }
}
void DFT(pr a[],int r=1) {
    Rev(a);
    for(int i=1;i<=NN;i<<=1) {
        pr wi=mpr(cos(2.0*Pi/i),r*sin(2.0*Pi/i));
        for(int j=0;j<NN;j+=i) {
            pr w=mpr(1.0,0.0);
            for(int k=j;k<j+i/2;k++) {
                pr x=a[k],y=w*a[k+i/2];
                a[k]=x+y,a[k+i/2]=x-y;
                w=w*wi;
            }
        }
    }if(r==-1) for(int i=0;i<NN;i++) a[i].r/=NN;
}
void FFT(pr a[],pr b[],pr c[]) {
    DFT(a,1),DFT(b,1);
    for(int i=0;i<NN;i++) c[i]=a[i]*b[i];
    DFT(c,-1);
}
 
inline int in(int x=0,char ch=getchar()) { while(ch>'9' || ch<'0') ch=getchar();
    while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();return x; }
 
int n;LL ans;
int a[N],b[N];
int bf[N],bd[N],tp[N];
pr x1[N],x2[N],x3[N];
 
int main() {
    n=in();
    for(int i=0;i<n;i++) a[i]=in();
     
    for(int i=0;i<n;i++) bd[a[i]]++;
     
    for(int j=0;j<n;j+=B) {
        for(int i=j;i<n && i<j+B;i++) bd[a[i]]--;
        //before and behind
        memset(x1,0,sizeof(x1)),memset(x2,0,sizeof(x2));
        for(int i=0;i<NN/2;i++) x1[i]=mpr(bf[i],0),x2[i]=mpr(bd[i],0);
         
        FFT(x1,x2,x3);
        for(int i=j;i<n && i<j+B;i++) tp[a[i]]++;
        for(int i=0;i<NN;i++) if(!(i&1)) ans+=(LL)tp[i/2]*(int)(x3[i].r+0.5);
        for(int i=j;i<n && i<j+B;i++) tp[a[i]]=0;
         
        for(int p=j;p<n && p<j+B;p++) for(int q=p+1;q<n && q<j+B;q++) {
            if(2*a[p]-a[q]>=0) ans+=bf[2*a[p]-a[q]];
            if(2*a[q]-a[p]>=0) ans+=bd[2*a[q]-a[p]];
        }
         
        for(int p=j;p<n && p<j+B;p++) {
            for(int q=j;q<p;q++) {
                if(2*a[q]-a[p]>=0) ans+=tp[2*a[q]-a[p]];
                tp[a[q]]++;
            }
            for(int q=j;q<p;q++) tp[a[q]]--;
        }
         
        for(int i=j;i<n && i<j+B;i++) bf[a[i]]++;
    }cout<<ans<<endl;
    return 0;
}

  

原文地址:https://www.cnblogs.com/beiyuoi/p/6549136.html