多项式学习笔记(三): 多项式全家桶

1.多项式求逆

给你 (A(x))(A(x)B(x) equiv 1 pmod {x^n}) 。 (模 (x^n) 是为了把高次项舍掉)

假设我们已经得到了满足 (C(x)A(x) equiv 1 pmod {x^{nover 2}}) 的一个多项式 (C)

那么由题意可得 (A(x)B(x)equiv 1 pmod {x^{nover 2}})

两式联立可得:

(B(x) equiv C(x) pmod {x^{nover 2}})

(B(x) - C(x) equiv 0 pmod {x^{nover 2}})

两边同时平方可得:

(B^2(x) + C^2(x) - 2B(x)C(x) equiv 0 pmod {x^{n}})

在同时乘上一个 (A(x)) 得:

(A(x)B^2(x) + A(x)C^2(x)-2A(x)B(x)C(x)equiv 0 pmod {x^{n}})

然后由题意可得 (A(x)B(x)equiv 1 pmod {x^n}) ,代入化简可得:

(B(x) + A(x)C^2(x)-2C(x) equiv 0 pmod {x^n})

(B(x) = 2C(x) - A(x)C^2(x))

然后,我们每次都可以把项数减半递归求解, 如果项数为 (1) 的话结果显然是零次项的逆元。

复杂度 (T(n) = T({nover 2}) + nlogn = nlogn)

Code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int N = 1e6+10;
const int p = 998244353;
int n,a[N],b[N],rev[N],c[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
}
int ksm(int a,int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
void NTT(int *a,int len,int opt)
{
    for(int i = 0; i < len; i++)
    {
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int h = 1; h < len; h <<= 1)
    {
        int wn = ksm(3,(p-1)/(h<<1));
        if(opt == -1) wn = ksm(wn,p-2);
        for(int j = 0; j < len; j += (h<<1))
        {
            int w = 1;
            for(int k = 0; k < h; k++)
            {
                int u = a[j + k];
                int v = w * a[j + h + k] % p;
                a[j + k] = (u + v) % p;
                a[j + h + k] = (u - v + p) % p;
                w = w * wn % p;
            }
        }
    }
    if(opt == -1)
    {
        int inv = ksm(len,p-2);
        for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
    }
}
void Inv(int n,int *a,int *b)//求 A(x)B(x) = 1 mod x^n
{
    if(n == 1)//项数为1的情况
    {
        b[0] = ksm(a[0],p-2);
        return;
    }
    Inv((n+1)>>1,a,b);//递归求 C(x)
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++)//预处理NTT的反转数组
    {
        rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    }
    //注意,不能用 a 来做多项式乘法,因为如果拿 a 做了多项式乘法,那么 a 的值在递归过程中,就会发生改变。
    for(int i = 0; i < n; i++) c[i] = a[i];//把 a 赋给 c,用 c 来做多项式乘法
    for(int i = n; i < lim; i++) c[i] = 0;//多余的高次项舍去
    //此时的 B 数组存的是 B(x)A(x) = 1 mod x^{n/2},C数组存的是 A(x)
    NTT(c,lim,1); NTT(b,lim,1);//求 B 和 C 的点值表示法
    for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;//计算 B的点值
    NTT(b,lim,-1);//把B转化为系数表示法
    for(int i = n; i < lim; i++) b[i] = 0;//高次项舍去 
}
signed main()
{
    n = read();
    for(int i = 0; i < n; i++) a[i] = read();
    Inv(n,a,b);
    for(int i = 0; i < n; i++) printf("%lld ",b[i]);
    printf("
");
    return 0;
}

2.多项式开根

(B^2(x) equiv A(x) pmod {x^n})

假设,我们得到了满足 (C^2(x) equiv A(x) pmod {x^{nover 2}}) 的一个多项式 (C(x))

又因为 (B^2(x) equiv A(x) pmod {x^{nover 2}})

两式联立可得:

(B^2(x) equiv C^2(x) pmod {x^{nover 2}})

(B^2(x)-C^2(x) equiv 0 pmod {x^{nover 2}})

两边同时平方可得:

(B^4(x) + C^4(x) - 2B^2(x)C^2(x) equiv 0 pmod {x^n})

两边同时加上 (4B^2(x)C^2(x)) 可得:

(B^4(x) + C^4(x) + 2B^2(x)C^2(x) equiv 4B^2(x)C^2(x) pmod {x^n})

((B^2(x) + C^2(x))^2 equiv 4B^2(x)C^2(x) pmod {x^n})

把右边的 (4C^2(x)) 除过去可得:

({(B^2(x) + C^2(x))^2 over 4C^2(x)}equiv B^2(x)pmod {x^n})

(B(x) equiv {B^2(x) + C^2(x)over 2C(x)} pmod {x^n})

又因为 (B^2(x) equiv A(x) pmod {x^n}) ,代入可得:

(B(x) equiv {A(x) + C^2(x)over 2C(x)} pmod {x^n})

还是像求逆一样每次项数减半,递归求解,当项数为 (1) 的时候答案为 (sqrt {常数项})

多项式求逆加NTT即可。

复杂度 (O(nlogn))

Code(常数爆炸):

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
#define int long long
const int N = 1e6+10;
const int p = 998244353;
int n,a[N],b[N],c[N],d[N],rev[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
}
int ksm(int a,int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
void NTT(int *a,int len,int opt)//NTT 板子
{
    for(int i = 0; i < len; i++)
    {
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int h = 1; h < len; h <<= 1)
    {
        int wn = ksm(3,(p-1)/(h<<1));
        if(opt == -1) wn = ksm(wn,p-2);
        for(int j = 0; j < len; j += (h<<1))
        {
            int w = 1;
            for(int k = 0; k < h; k++)
            {
                int u = a[j + k];
                int v = w * a[j + h + k] % p;
                a[j + k] = (u + v) % p;
                a[j + h + k] = (u - v + p) % p;
                w = w * wn % p;
            }
        }
    }
    if(opt == -1)
    {
        int inv = ksm(len,p-2);
        for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
    }
}
void Inv(int n,int *a,int *b)//多项式求逆板子
{
    if(n == 1)
    {
        b[0] = ksm(a[0],p-2);
        return;
    }
    Inv((n+1)>>1,a,b);
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    for(int i = 0; i < n; i++) c[i] = a[i];
    for(int i = n; i < lim; i++) c[i] = 0;
    NTT(c,lim,1); NTT(b,lim,1);
    for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
    NTT(b,lim,-1);
    for(int i = n; i < lim; i++) b[i] = 0;//记得清空
}
void sqrt(int n,int *a,int *b)
{
    if(n == 1)//项数为 1的情况
    {
        b[0] = (int) sqrt(a[0]);
        return;
    } 
    sqrt((n+1)>>1,a,b);    
	Inv(n,b,d);//这里求 mod x^n 下的逆元,而不是 mod x^lim 下的逆元 
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    for(int i = 0; i < n; i++) c[i] = a[i];//用c数组代替a来做多项式乘法
    for(int i = n; i < lim; i++) c[i] = 0;
    //这里 b 数组存的是 C^2(x) = A(x) mod x^{n/2}
    // c数组 存的是 A(x), d数组存的是 C(x) 的乘法逆
    NTT(b,lim,1); NTT(c,lim,1); NTT(d,lim,1);
    int inv2 = ksm(2,p-2);
    for(int i = 0; i < lim; i++) b[i] = (b[i] * b[i] % p + c[i] % p) * d[i] % p * inv2 % p;//根据柿子算出 B(x) 的点值
    NTT(b,lim,-1);//转换为系数表示法
    for(int i = n; i < lim; i++) b[i] = 0;   
    for(int i = 0; i < lim; i++) d[i] = 0;//多次调用要清空
} 
signed main()
{
    n = read();
    for(int i = 0; i < n; i++) a[i] = read();
    sqrt(n,a,b);
    for(int i = 0; i < n; i++) printf("%lld ",b[i]);
    return 0;
}

3.多项式求导

(A(x) = displaystylesum_{i=0}^{n} a_ix^i) , 则 (A^prime(x) = displaystylesum_{i=0}^{n} ia_{i}x^{i-1})

void qiudao(int len,int *a,int *b)
{
    for(int i = 1; i < len; i++) b[i-1] = i * a[i] % p;
    b[len-1] = 0;
}

5.多项式积分

(A(x) = displaystylesum_{i=0}^{n}a_ix^i) ,则 (int A(x) = displaystylesum_{i=1}^{n} {a_iover i+1} x^{i+1})

void jifen(int len,int *a,int *b)
{
    for(int i = 1; i < len; i++) b[i] = a[i-1] * ksm(i,p-2) % p;
    b[0] = 0;
}

6.多项式 ln

(B(x) equiv lnA(x) pmod {x^n})

(F(x) = lnA(x)) ,则 对等式两边同时求导可得:

(B^prime(x) equiv F^prime(x) pmod {x^n})

根据复合函数求导公式 (f^prime(g(x)) = f^prime(g(x)) g^prime(x)) 可得:

(B^prime(x) equiv {A^prime (x)over A(x)} pmod {x^n})

先求出 (A(x)) 的导函数和乘法逆,在相乘得到 (B^prime(x)) ,最后在积分回去即可。

多项式求逆,多项式求导,多项式积分,多项式乘法。

复杂度 (O(nlogn))

code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int p = 998244353;
const int N = 1e6+10;
int n,a[N],b[N],c[N],rev[N],A[N],B[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
}
int ksm(int a,int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
void NTT(int *a,int len,int opt)
{
    for(int i = 0; i < len; i++)
    {
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int h = 1; h < len; h <<= 1)
    {
        int wn = ksm(3,(p-1)/(h<<1));
        if(opt == -1) wn = ksm(wn,p-2);
        for(int j = 0; j < len; j += (h<<1))
        {
            int w = 1;
            for(int k = 0; k < h; k++)
            {
                int u = a[j + k];
                int v = w * a[j + h + k] % p;
                a[j + k] = (u + v) % p;
                a[j + h + k] = (u - v + p) % p;
                w = w * wn % p;
            }
        }
    }
    if(opt == -1)
    {
        int inv = ksm(len,p-2);
        for(int i = 0; i < len; i++) a[i] = a[i] * inv % p;
    }
}
void Inv(int n,int *a,int *b)
{
    if(n == 1)
    {
        b[0] = ksm(a[0],p-2);
        return;
    }
    Inv((n+1)>>1,a,b);
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    for(int i = 0; i < n; i++) c[i] = a[i];
    for(int i = n; i < lim; i++) c[i] = 0;
    NTT(b,lim,1); NTT(c,lim,1);
    for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
    NTT(b,lim,-1);
    for(int i = n; i < lim; i++) b[i] = 0;
}
void qiudao(int len,int *a,int *b)
{
    for(int i = 1; i < len; i++) b[i-1] = i * a[i] % p;
    b[len-1] = 0;
}
void jifen(int len,int *a,int *b)
{
    for(int i = 1; i < len; i++) b[i] = a[i-1] * ksm(i,p-2) % p;
    b[0] = 0;
}
void Ln(int n,int *a,int *b)
{
    Inv(n,a,A); qiudao(n,a,B);//A 存的是 a的乘法逆,B存的是 a的导函数
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    NTT(A,lim,1); NTT(B,lim,1);
    for(int i = 0; i < lim; i++) B[i] = B[i] * A[i] % p;
    NTT(B,lim,-1); jifen(lim,B,b);//B存的是 b 的导函数
    for(int i = n; i < lim; i++) b[i] = 0;
}
signed main()
{
    n = read();
    for(int i = 0; i < n; i++) a[i] = read();
    Ln(n,a,b);
    for(int i = 0; i < n; i++) printf("%lld ",b[i]);
    return 0;
}

7.多项式除法

给你一个 (n) 次多项式 (A(x)) 和一个 (m) 次的多项式 (B(x)),求多项式 (C(x))(D(x)) 满足:

  1. (C(x)) 的次数为 (n-m), (D(x)) 的次数小于 (m)
  2. (A(x) = C(x) * B(x) + D(x))

(f(x)) 是一个 (n) 次多项式,则定义 (inv(f(x)) = x^nf({1over x}))

(inv(f(x)) = x^n f({1over x}) = x^n(a_0+a_1x^{-1}+...a_nx^{-n}) = a_{n} + a_{n-1}x^1 + a_{n-2}x^2+....a_{1}x^{n-1} + a_0x^{n})

所以 (inv(f(x))) 其实就是把 (f(x)) 的系数反转过来得到的结果。

(ecause A(x) = C(x) * B(x) + D(x))

所以有 (inv(A(x)) = inv(C(x) * B(x) + D(x)))

展开可得:

(x^nA({1over x}) = x^{n} (C({1over x}) * B({1over x}) + D({1over x})))

(x^nA({1over x}) = x^mB({1over x}) x^{n-m} C({1over x}) + x^{n-m+1} x^{m-1} D({1over x}))

在转化为 (inv(f(X))) 可得:

(inv(A(x)) = inv(B(x))inv(C(x)) + x^{n-m+1}inv(D(x)))

两边同时模上 (x^{n-m+1}) 可得:

(invA(x) equiv inv(B(x))inv(C(x)) pmod {x^{n-m+1}})

(inv(C(x)) equiv {inv(A(x))over invB(x)} pmod {x^{n-m+1}})

多项式乘法和多项式求逆可以求出来 (inv(C(x))), 在把系数反转得到 (C(x)).

最后把 (C(x)) 代入原式可得到 (D(x)).

复杂度 (O(nlogn))

一定要注意清空数组(我这个沙比就因为这个卡在了50分好几回)

Code:

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int p = 998244353;
const int N = 1e6+10;
int n,m,rev[N],a[N],b[N],c[N],d[N],A[N],B[N],invB[N];
inline int read()
{
    int s = 0,w = 1; char ch = getchar();
    while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
    return s * w;
}
int ksm(int a,int b)
{
    int res = 1;
    for(; b; b >>= 1)
    {
        if(b & 1) res = res * a % p;
        a = a * a % p;
    }
    return res;
}
void NTT(int *a,int len,int opt)
{
    for(int i = 0; i < len; i++)
    {
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    }
    for(int h = 1; h < len; h <<= 1)
    {
        int wn = ksm(3,(p-1)/(h<<1));
        if(opt == -1) wn = ksm(wn,p-2);
        for(int j = 0; j < len; j += (h<<1))
        {
            int w = 1;
            for(int k = 0; k < h; k++)
            {
                int u = a[j + k];
                int v = w * a[j + h + k] % p;
                a[j + k] = (u + v) % p;
                a[j + h + k] = (u - v + p) % p;
                w = w * wn % p;
            }
        }
    }
    if(opt == -1)
    {
        int inv = ksm(len,p-2);
        for(int i = 0; i < len; i++) a[i] = (a[i] * inv % p + p) % p;
    }
}
void Inv(int n,int *a,int *b)
{
    if(n == 1)
    {
        b[0] = ksm(a[0],p-2);
        return;
    }
    Inv((n+1)>>1,a,b);
    int lim = 1, tim = 0;
    while(lim < (n<<1)) lim <<= 1, tim++;
    for(int i = 0; i < lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
    for(int i = 0; i < n; i++) c[i] = a[i];
    for(int i = n; i < lim; i++) c[i] = 0;
    NTT(c,lim,1); NTT(b,lim,1);
    for(int i = 0; i < lim; i++) b[i] = (2 * b[i] % p - b[i] * b[i] % p * c[i] % p + p) % p;
    NTT(b,lim,-1); 
    for(int i = n; i < lim; i++) b[i] = 0;
}
void mul(int n,int m,int *a,int *b)
{
	int lim = 1, tim = 0;
	while(lim < (n<<1)) lim <<= 1, tim++;
	for(int i = 0; i <lim; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(tim-1));
	NTT(a,lim,1); NTT(b,lim,1);
	for(int i = 0; i < lim; i++) a[i] = a[i] * b[i] % p;
	NTT(a,lim,-1);
	for(int i = n; i < lim; i++) a[i] = 0; 
}
void Chu(int n,int m,int *a,int *b)
{
    for(int i = 0; i < n; i++) A[i] = a[n-i-1];//A 数组存的是 inv(A(x))
    for(int i = 0; i < m; i++) B[i] = b[m-i-1];//B 数组存的是 inv(B(x))
    Inv(n-m+1,B,invB); 
    for(int i = n-m+1; i < (n<<2); i++) A[i] = invB[i] = 0;
    mul(n-m+1,n-m+1,A,invB); 
    for(int i = 0; i < n-m+1; i++) c[i] = (A[n-m-i] % p + p) % p;
	for(int i = 0; i < n-m+1; i++) printf("%lld ",c[i]); 
    printf("
");
    for(int i = n-m+1; i < (n<<2); i++) c[i] = 0;
    mul(n,n,c,b);
    for(int i = 0; i < m-1; i++) d[i] = ((a[i] - c[i]) % p + p) % p;
    for(int i = 0; i < m-1; i++) printf("%lld ",d[i]);
} 
signed main()
{
    n = read() + 1; m = read() + 1;
    for(int i = 0; i < n; i++) a[i] = read();
    for(int i = 0; i < m; i++) b[i] = read();
    Chu(n,m,a,b);
    return 0;
}
原文地址:https://www.cnblogs.com/genshy/p/14260473.html