BZOJ 3513: [MUTC2013]idiots FFT

Description

给定n个长度分别为a_i的木棒,问随机选择3个木棒能够拼成三角形的概率。

Input

第一行T(T<=100),表示数据组数。
接下来若干行描述T组数据,每组数据第一行是n,接下来一行有n个数表示a_i。
3≤N≤10^5,1≤a_i≤10^5

Output

T行,每行一个整数,四舍五入保留7位小数。

题解: 

三角形的三条边要满足最小边与次小边之和要小于最长边之和
 
令 $f_{i}$ 表示两边之和为 $i$ 的数量.
 
那么合法的三角形数量应为 $sum_{i=1}^{Max}f_{i} imes g_{i-1}$ ($g_{i}$ 表示长度小于等于 $i$ 的数量)
 
然而这样做其实十分麻烦,因为 $g_{i-1}$ 中与 $f_{i}$ 中是会有重复元素的
 
我们变一下,令 $g_{i}$ 表示长度大于等于 $i$ 的数量
 
那么不合法的情况为 $sum_{i=1}^{Max}f_{i} imes g_{i}$,可以用总数量减掉不合法数量来求合法数量
 
构造生成函数 $A=sum_{i=1}^{Max}a_{i}x^i$, $a_{i}$ 表示长度为 $i$ 的边有多少个
 
那么 $f=A^2$ 就是两边结合的情况,用 $FFT$ 来加速
 
要注意当 $i$ 为偶数时,相同的边也会贡献一次,所以要先减掉这些相同边
 
然后发现我们这么结合时有序的,而实际上边应该是无序的,所以还需要 $/2$
 
得到正确的 $f$ 后一次枚举每一个 $i$,与 $g_{i}$ 结合即可
#include<bits/stdc++.h> 
#define setIO(s) freopen(s".in","r",stdin) 
#define maxn 400003 
#define ll long long 
using namespace std;  
namespace IO 
{
    inline int read()
    {
        int ans=0;
        char ch=getchar();
        while(!isdigit(ch))ch=getchar();
        while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
        return ans;
    }
};  
const double pi=acos(-1.0); 
struct cpx
{ 
    double x,y; 
    cpx(double a=0,double b=0){ x=a,y=b; } 
    cpx operator+(const cpx b) { return cpx(x+b.x, y+b.y); } 
    cpx operator-(const cpx b) { return cpx(x-b.x, y-b.y); } 
    cpx operator*(const cpx b) { return cpx(x*b.x-y*b.y,x*b.y+y*b.x); }
}A[maxn],B[maxn];      
inline void FFT(cpx *a,int n,int flag)
{
    for(int i=0,k=0;i<n;++i) 
    {
        if(i>k) swap(a[i], a[k]); 
        for(int j=(n>>1);(k^=j)<j;j>>=1); 
    } 
    for(int mid=1;mid<n;mid<<=1)
    {
        cpx wn(cos(pi/mid), flag*sin(pi/mid)),x,y; 
        for(int i=0;i<n;i+=(mid<<1))
        {
            cpx w(1,0); 
            for(int j=0;j<mid;++j) 
            {
                x=a[i+j],y=w*a[i+j+mid];       
                a[i+j]=x+y, a[i+j+mid]=x-y;    
                w=w*wn;    
            }
        }
    }
    if(flag==-1) for(int i=0;i<n;++i) a[i].x/=(double)n;    
}
int f[maxn], arr[maxn], g[maxn]; 
ll answer[maxn];       
inline void solve() 
{
    int n,Max=0,len; 
    n=IO::read();       
    for(int i=1;i<=n;++i) arr[i]=IO::read(), ++f[arr[i]], Max=max(Max, arr[i]);    
    for(int i=Max;i>=1;--i) g[i]=f[i]+g[i+1];           
    for(int i=1;i<=Max;++i) A[i].x=(double)f[i];                
    for(len=1;len<=(Max<<1);len<<=1); 
    FFT(A,len,1);   
    for(int i=0;i<len;++i) A[i]=A[i]*A[i];   
    FFT(A,len,-1);           
    for(int i=1;i<len;++i) 
    {
        answer[i]=(ll)(A[i].x+0.5);     
        if(!answer[i]) continue;   
        if(i%2==0) answer[i]-=f[i>>1];    
        answer[i]>>=1;   
    }   
    ll up,down; 
    up=down=(ll)n*(n-1)*(n-2)/6;        
    for(int i=0;i<=len;++i) up-=answer[i]*(ll)g[i];   
    printf("%.7f
",(double)up/(double)down);       
    memset(A,0,sizeof(A)), memset(f,0,sizeof(f)), memset(g,0,sizeof(g));    
}
int main()
{ 
    // setIO("input"); 
    int T; 
    T=IO::read();    
    while(T--) solve(); 
    return 0; 
}

  

原文地址:https://www.cnblogs.com/guangheli/p/11168104.html