HDU 4609 3-idiots (FFT)

题目链接：http://acm.hdu.edu.cn/showproblem.php?pid=4609

给出n根火柴棍的长度，从中任取3根，问能组成三角形的概率是多少。思路可以参考kuangbin大神的博客，写得很详细：http://www.cnblogs.com/kuangbin/archive/2013/07/24/3210565.html

注意计数和乘积都要用long long以免溢出int，其他就没问题了。

#include <bits/stdc++.h>
using namespace std;
// Shared type aliases, constants and global buffers used by fft() and main().
using ll = long long;                     // 64-bit counter type (pair counts can exceed int range)
#define mem(a) memset(a, 0, sizeof(a))    // zero a fixed-size array in place
using cmx = complex<double>;              // complex value type used by the FFT
const double PI = acos(-1.0);
const int maxn = 400005;                  // capacity of the FFT / count buffers
cmx x[maxn];                              // FFT working buffer
int a[maxn / 4];                          // input stick lengths
ll num[maxn];                             // per-length counts, later pair-sum counts and prefix sums
// Permute x[0..len) into bit-reversed index order, the input layout required
// by the iterative FFT passes in fft(). Callers in this file pass a power of two.
void change(cmx x[], int len) {
    // `rev` holds the bit-reversal of `idx` and is advanced by a "reversed
    // increment": strip set bits from the top down, then set the first clear one.
    int rev = len / 2;
    for (int idx = 1; idx < len - 1; ++idx) {
        if (idx < rev)             // swap each mirrored pair exactly once
            swap(x[idx], x[rev]);
        int bit = len / 2;
        for (; rev >= bit; bit /= 2)
            rev -= bit;
        rev += bit;                // the loop above always exits with rev < bit
    }
}
// In-place iterative radix-2 FFT over x[0..len).
// on == 1  : forward transform (twiddle exponent -2*pi*i/step).
// on == -1 : inverse transform, including the final division by len.
// Callers in this file always pass a power-of-two len.
void fft(cmx x[], int len, int on) {
    change(x, len);                                  // inputs must be in bit-reversed order
    for (int step = 2; step <= len; step <<= 1) {    // length of the sub-transforms being merged
        const int half = step / 2;
        // root of unity for this stage; its direction is chosen by `on`
        cmx wn(cos(-on * 2 * PI / step), sin(-on * 2 * PI / step));
        for (int base = 0; base < len; base += step) {
            cmx w(1, 0);
            for (int k = base; k < base + half; ++k) {
                // butterfly: merge the two half-size transforms of this segment
                cmx lo = x[k];
                cmx hi = x[k + half] * w;
                x[k] = lo + hi;
                x[k + half] = lo - hi;
                w *= wn;
            }
        }
    }
    if (on != -1)
        return;
    for (int i = 0; i < len; ++i)                    // inverse transform needs 1/len scaling
        x[i] /= len;
}
int main()
{
    int t, n;
    cin>>t;
    while (t--) {
        cin>>n;
        mem(num);
        int maxx = 0;
        for (int i = 0; i < n; i++) {
            scanf("%d", a + i);
            num[a[i]]++;
            maxx = max(maxx, a[i]);
        }
        sort(a, a + n);
        int len = 1;
        maxx++;
        while (len < 2*maxx) {
            len <<= 1;
        }
        for (int i = 0; i < maxx; i++) {
            x[i] = cmx(num[i], 0);
        }
        for (int i = maxx; i < len; i++) {
            x[i] = cmx(0, 0);
        }
        fft(x, len, 1);
        for (int i = 0; i < len; i++) {
            x[i] *= x[i];
        }
        fft(x, len, -1);
        for (int i = 0; i < len; i++) {
            num[i] = (ll)(x[i].real()+0.5);
        }
        for (int i = 0; i < n; i++) {
            num[a[i]+a[i]]--;
        }
        for (int i = 0; i < len; i++) {
            num[i] /= 2;
        }
        for (int i = 1; i < len; i++) {
            num[i] += num[i-1];
        }
        ll ans = 0;
        for (int i = 0; i < n; i++) {
            ans += num[len-1] - num[a[i]];
            ans -= 1LL * (n-i-1) * i;
            ans -= 1LL * (n-i-1) * (n-i-2) / 2;
        }
        ans -= 1LL * n * (n-1);
        ll sum = 1LL * n * (n-1) * (n-2) / 6;
        printf("%.7f
", 1.0*ans/sum);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/yohaha/p/5920949.html