HDU1402 HDU4609 FFT快速DFT

原理具体内容可见算法导论第30章,很详细,部分线性代数知识

简单陈述:多项式可表示成点值表达式。次数界为n的多项式可以由n个点对唯一表示,证明可由矩阵行列式不为0,矩阵可逆证明。

设次数界为n的多项式A(x) = {(x0, y0), (x1,y1), ……,(xn-1, yn-1)},则另xi为x^n=1的n个复根w,则求出来的n维向量y成为多项式系数向量a的DFT变换。

FFT使用分治方法来快速计算DFT

HDU1402

#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
using namespace std;
const double PI = acos(-1.0);
struct Complex
{
    double real, img;
    Complex(){real = img = 0;}
    Complex(double a, double b) : real(a), img(b) {}
    Complex operator+(const Complex a)
    {
        return Complex(real+a.real, img+a.img);
    }
    Complex operator-(const Complex a)
    {
        return Complex(real-a.real, img-a.img);
    }
    Complex operator*(const Complex a)
    {
        return Complex(real*a.real-img*a.img, real*a.img+img*a.real);
    }
};

int bit_revserve(int id, int len)
{
    int ans = 0, p;
    for(int i = 0; (1<<i) < len; i++)
    {
        ans <<= 1;
        if(id&(1<<i)) ans |= 1;
    }
    return ans;
}

Complex A[140000];
void FFT(Complex *a, int len, int DFT)
{
    for(int i = 0; i < len; i++)
        A[bit_revserve(i, len)] = a[i];
    for(int s = 1; (1<<s) <= len; s++)
    {
        int m = 1<<s;
        Complex wm(cos(DFT*2*PI/m), sin(DFT*2*PI/m));

        for(int k = 0; k < len; k += m)
        {
            Complex w(1,0);
            for(int j = 0; j < (m/2); j++)
            {
                Complex u = A[k+j];
                Complex t = A[k+j+m/2]*w;
                A[k+j] = u+t;
                A[k+j+m/2] = u-t;
                w = w*wm;
            }
        }
    }
    if(DFT == -1)
        for(int i = 0; i < len; i++)
            A[i].real /= len, A[i].img /= len;
    for(int i = 0; i < len; i++)
        a[i] = A[i];
}

char s[50050],ss[50050];
Complex a[140050];
Complex b[140050];
Complex c[140050];
int ans[140050];
int main()
{
    while(scanf("%s", s) != EOF)
    {
        memset(a, 0, sizeof(a));
        memset(b, 0, sizeof(b));
        int lena = strlen(s);
        int k = 0;
        scanf("%s", ss);
        int lenb = strlen(ss);
        while((1<<k) < lena) k++;
        while((1<<k) < lenb) k++;
        int len = 1 << (k+1);

        for(int i = 0; i < len; i++)
        {
            if(i < lena) a[i] = Complex(s[lena - i - 1] - '0', 0);
            else a[i] = Complex(0, 0);
            if(i < lenb) b[i] = Complex(ss[lenb - i - 1] - '0', 0);
            else b[i] = Complex(0, 0);
        }

        FFT(a, len, 1);
        FFT(b, len, 1);
        for(int i = 0; i < len; i++)
            c[i] = a[i]*b[i];
        FFT(c, len, -1);
        memset(ans, 0, sizeof(ans));
        for(int i = 0; i < len; i++)
        {
            ans[i] = int(c[i].real+0.5);
        }
        for(int i = 0; i < len; i++)
        {
            ans[i+1] += ans[i]/10;
            ans[i] %= 10;
        }
        int flag = 0;
        for(int i = len; i >= 0; i--)
        {
            if(ans[i]) printf("%d",ans[i]),flag = 1;
            else if(flag||i==0) printf("0");
        }
        puts("");
    }
}

 

 

HDU4609,上题比较裸,这题要多加思考,找到可以组成三角形的三个树枝的数量。进行一次FFT,也就是多项式相乘,那么对应x^n的系数就是用两个树枝的总长度为n的组合种数,这里的两个树枝是可以重复的,而且<a,b><b,a>这样选两个树枝是算两个,所以要去重,剪掉两个相同的树枝组成的n,然后除以二,保证无序性。然后按三角形最长的边枚举,两两之和大于当前枚举的边,剪掉两个都比它长的,一个比他长的,一个比他短的,一个就是它的三种情况就好了。

#include <iostream>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <math.h>
#include <algorithm>
using namespace std;
const int MAXN = 1e5*3+50;
const double PI = acos(-1.0);
typedef long long LL;
struct Complex
{
    double real,image;
    Complex(){real = image = 0;}
    Complex(double a, double b):real(a), image(b) {}
    Complex operator+(Complex a){ return Complex(real+a.real, image+a.image);}
    Complex operator-(Complex a){ return Complex(real-a.real, image-a.image);}
    Complex operator*(Complex a){ return Complex(real*a.real-image*a.image, real*a.image+a.real*image);}
};

int rev(int id, int len)
{
    int ret = 0;
    for(int i = 0; (1<<i) < len; i++)
    {
        ret <<= 1;
        if(id&(1<<i)) ret |= 1;
    }
    return ret;
}

Complex A[MAXN];

void FFT(Complex a[], int len, int DFT)
{
    for(int i = 0; i < len; i++)
        A[rev(i,len)] = a[i];
    for(int s = 1; (1<<s) <= len; s++)
    {
        int m = 1<<s;
        Complex wm(cos(PI*2*DFT/m), sin(PI*2*DFT/m));
        for(int k = 0; k < len; k += m)
        {
            Complex w(1,0);
            for(int j = 0; j < m/2; j++)
            {
                Complex u = A[k+j];
                Complex t = A[k+j+m/2]*w;
                A[k+j] = u+t;
                A[k+j+m/2] = u-t;
                w = w*wm;
            }
        }
    }
    if(DFT == -1)
        for(int i = 0; i < len; i++)
            A[i].real /= len, A[i].image /= len;
    for(int i= 0; i < len; i++)
        a[i] = A[i];
}
const int N =300000 + 50;
int T;
Complex C[N];
int a[N];
LL ans[N];
int num[N];
LL sum[N];
int main()
{
    scanf("%d", &T);
    while(T--)
    {
        int n,maxv = -1;
        scanf("%d", &n);
        memset(num, 0, sizeof(num));
        for(int i = 0; i < n; i++)
        {
            scanf("%d", &a[i]);
            num[a[i]]++;
            maxv = max(a[i], maxv);
        }

        int len = 0;
        while((1<<len) <= maxv) len++;
        len = 1<<(len+1);

        for(int i = 0; i < len; i++)
        {
            C[i] = Complex(num[i], 0);
        }
        FFT(C, len, 1);

        for(int i = 0; i < len; i++)
            C[i] = C[i]*C[i];

        FFT(C, len, -1);



        for(int i = 0; i <= 2*maxv; i++)
            ans[i] = (LL)(C[i].real+0.5);
        /*for(int i = 0; i <= 2*maxv; i++)
            printf("%d  ", ans[i]);*/
        for(int i = 0; i < n; i++)
            ans[a[i]+a[i]]--;
        for(int i = 0; i <= 2*maxv; i++)
            ans[i] /= 2;
        sum[0] = 0;
        for(int i = 1; i <= 2*maxv; i++)
            sum[i] = sum[i-1] + ans[i];

        sort(a, a + n);
        LL ret = 0LL;
        for(int i = 1; i <= n; i++)
        {
            LL tmp = sum[2*maxv] - sum[a[i-1]];
            //printf("ret = %lld
", ret);
            tmp -= (LL)(n-i)*(i-1);
            tmp -= (LL)(n-i)*(n-i-1)/2;
            tmp -= (LL)(n-1);

            ret += tmp;
        }
        //cout<<ret<<endl;
        printf("%.7f
", 1.0*ret*6/n/(n-1)/(n-2));

    }
    return 0;
}
如果有错误,请指出,谢谢
原文地址:https://www.cnblogs.com/Alruddy/p/7266804.html