$NTT$(快速数论变换)

- 概念引入

  - 阶
    对于$p in N_+$且$(a, p) = 1$,满足$a^r equiv 1 (mod p)$的最小的非负$r$为$a$模$p$意义下的阶,记作$delta_p(a)$


  - 原根
    定义:若$p in N_+$且$a in N$,若$delta_p(a) = phi(p)$,则称$a$为模$p$的一个原根
    相关定理:
      - 若一个数$m$拥有原根,那么它必定为$2, 4, p^t, 2p^t (p$为奇质数$)$的其中一个
      - 每个数$p$都有$phi(phi(p))$个原根
      证明:若$p in N_+$且$(a, p) = 1$,正整数$r$满足$a^r equiv 1 (mod p)$,那么$delta(p) | r$,由此推广,可知$delta(p) | phi(p)$,所以$p$的原根个数即为$p$之前与$phi(p)$互质的数,即$phi(p)$故定理成立
      - 若$g$是$m$的一个原根,则$g, g^1, g^2, ..., g^{phi(m)} (mod p)$两两不同
    原根求法:
      将$phi(m)$质因数分解,得$phi(m) = p_1^{c_1} * p_2^{c_2} * ... * p_k^{c_k}$
      那么所有$g$满足$g^{frac{phi(m)}{p_i}} eq 1 (mod m)$即为$m$的原根

- $NTT$

  由于$FTT$涉及到复数的运算,所以常数很大,而$NTT$仅需使用长整型,可大大优化常数

  能够将原根代替单位根进行计算,是因为它们的性质相似,至少在单位根需要的那几个性质原根都满足,当然,要能够进行$NTT$,需要满足模数$p$为质数,且$p = ax + 1$其中$x$为$2$的次幂,那么一般能满足条件的数(常用)有:
  $| p | g |$

  $| 469762049  | 3 |$

  $| 998244353 | 3 |$

  $| 1004535809 | 3 |$
  那么,就可以将单位根$omega_n$替换为$g^{frac{p - 1}{n}}$进行$NTT$了

- 代码 

  1 #include <iostream>
  2 #include <cstdio>
  3 #include <cstring>
  4 #include <algorithm>
  5 #include <cmath>
  6 
  7 #define MOD 998244353
  8 #define g 3
  9 
 10 using namespace std;
 11 
 12 typedef long long LL;
 13 
 14 const int MAXN = (1 << 22);
 15 
 16 LL power (LL x, int p) {
 17     LL cnt = 1;
 18     while (p) {
 19         if (p & 1)
 20             cnt = cnt * x % MOD;
 21 
 22         x = x * x % MOD;
 23         p >>= 1;
 24     }
 25 
 26     return cnt;
 27 }
 28 
 29 const LL invg = power (g, MOD - 2);
 30 
 31 int N, M;
 32 LL A[MAXN], B[MAXN];
 33 
 34 int oppo[MAXN];
 35 int limit;
 36 void NTT (LL* a, int inv) {
 37     for (int i = 0; i < limit; i ++)
 38         if (i < oppo[i])
 39             swap (a[i], a[oppo[i]]);
 40     for (int mid = 1; mid < limit; mid <<= 1) {
 41         LL ome = power (inv == 1 ? g : invg, (MOD - 1) / (mid << 1));
 42         for (int n = mid << 1, j = 0; j < limit; j += n) {
 43             LL x = 1;
 44             for (int k = 0; k < mid; k ++, x = x * ome % MOD) {
 45                 LL a1 = a[j + k], xa2 = x * a[j + k + mid] % MOD;
 46                 a[j + k] = (a1 + xa2) % MOD;
 47                 a[j + k + mid] = (a1 - xa2 + MOD) % MOD;
 48             }
 49         }
 50     }
 51 }
 52 
 53 int getnum () {
 54     int num = 0;
 55     char ch = getchar ();
 56 
 57     while (! isdigit (ch))
 58         ch = getchar ();
 59     while (isdigit (ch))
 60         num = (num << 3) + (num << 1) + ch - '0', ch = getchar ();
 61 
 62     return num;
 63 }
 64 
 65 int main () {
 66     N = getnum (), M = getnum ();
 67     for (int i = 0; i <= N; i ++)
 68         A[i] = (int) getnum ();
 69     for (int i = 0; i <= M; i ++)
 70         B[i] = (int) getnum ();
 71 
 72     int n, lim = 0;
 73     for (n = 1; n <= N + M; n <<= 1, lim ++);
 74     for (int i = 0; i <= n; i ++)
 75         oppo[i] = (oppo[i >> 1] >> 1) | ((i & 1) << (lim - 1));
 76     limit = n;
 77     NTT (A, 1);
 78     NTT (B, 1);
 79     for (int i = 0; i <= n; i ++)
 80         A[i] = A[i] * B[i] % MOD;
 81     NTT (A, - 1);
 82     LL invn = power (n, MOD - 2);
 83     for (int i = 0; i <= N + M; i ++) {
 84         if (i)
 85             putchar (' ');
 86         printf ("%d", (int) (A[i] * invn % MOD));
 87     }
 88     puts ("");
 89 
 90     return 0;
 91 }
 92 
 93 /*
 94 1 2
 95 1 2
 96 1 2 1
 97 */
 98 
 99 /*
100 5 5
101 1 7 4 0 9 4
102 8 8 2 4 5 5
103 */
NTT

- 任意模数$NTT$(三模数$NTT$法)

  有公式

$left{egin{aligned} x equiv a_1 (mod m_1) \ x equiv a_2 (mod m_2) \ x equiv a_3 (mod m_3) end{aligned} ight.$

  直接乘会爆$long long$,就先将上面的用$CRT$合并,得

$left{egin{aligned} x equiv A (mod M) \ x equiv a_3 (mod m_3) end{aligned} ight.$

  那么设$Ans = kM + A$,则有
  

$kM + A equiv a_3 (mod m_3)$
$ k equiv (a_3 - A)M^{- 1} (mod m_3)$

  直接处理即可

- 代码(任意模数$NTT$)

  1 #include <iostream>
  2 #include <cstdio>
  3 #include <cstring>
  4 
  5 using namespace std;
  6 
  7 typedef long long LL;
  8 
  9 const int MAXN = (1 << 20);
 10 
 11 const LL MOD[3]= {469762049, 998244353, 1004535809}; // 三模数
 12 const LL g = 3;
 13 const long double eps = 1e-03;
 14 
 15 LL multi (LL a, LL b, LL p) { // 快速乘
 16     a %= p, b %= p;
 17     return ((a * b - (LL) ((LL) ((long double) a / p * b + eps) * p)) % p + p) % p;
 18 }
 19 LL power (LL x, LL p, LL mod) {
 20     LL cnt = 1;
 21     while (p) {
 22         if (p & 1)
 23             cnt = cnt * x % mod;
 24 
 25         x = x * x % mod;
 26         p >>= 1;
 27     }
 28 
 29     return cnt;
 30 }
 31 const LL invg[3]= {power (g, MOD[0] - 2, MOD[0]), power (g, MOD[1] - 2, MOD[1]), power (g, MOD[2] - 2, MOD[2])};
 32 
 33 int N, M;
 34 LL P;
 35 
 36 LL A[MAXN], B[MAXN];
 37 
 38 int limit;
 39 int oppo[MAXN];
 40 void NTT (LL* a, int inv, int type) {
 41     for (int i = 0; i < limit; i ++)
 42         if (i < oppo[i])
 43             swap (a[i], a[oppo[i]]);
 44     for (int mid = 1; mid < limit; mid <<= 1) {
 45         LL ome = power (inv == 1 ? g : invg[type], (MOD[type] - 1) / (mid << 1), MOD[type]);
 46         for (int n = mid << 1, j = 0; j < limit; j += n) {
 47             LL x = 1;
 48             for (int k = 0; k < mid; k ++, x = x * ome % MOD[type]) {
 49                 LL a1 = a[j + k], xa2 = x * a[j + k + mid] % MOD[type];
 50                 a[j + k] = (a1 + xa2) % MOD[type];
 51                 a[j + k + mid] = (a1 - xa2 + MOD[type]) % MOD[type];
 52             }
 53         }
 54     }
 55 }
 56 
 57 LL ntta[3][MAXN], nttb[3][MAXN];
 58 void NTT_Main () {
 59     int n, lim = 0;
 60     for (n = 1; n <= N + M; n <<= 1, lim ++);
 61     limit = n;
 62     for (int i = 0; i < n; i ++)
 63         oppo[i] = (oppo[i >> 1] >> 1) | ((i & 1) << (lim - 1));
 64     for (int i = 0; i < 3; i ++) {
 65         for (int j = 0; j < n; j ++)
 66             ntta[i][j] = A[j];
 67         for (int j = 0; j < n; j ++)
 68             nttb[i][j] = B[j];
 69         NTT (ntta[i], 1, i);
 70         NTT (nttb[i], 1, i);
 71         for (int j = 0; j < n; j ++)
 72             ntta[i][j] = ntta[i][j] * nttb[i][j] % MOD[i];
 73         NTT (ntta[i], - 1, i);
 74         LL invn = power (n, MOD[i] - 2, MOD[i]);
 75         for (int j = 0; j <= N + M; j ++)
 76             ntta[i][j] = ntta[i][j] * invn % MOD[i];
 77     }
 78 }
 79 
 80 LL ans[MAXN];
 81 void CRT () {
 82     LL m = MOD[0] * MOD[1];
 83     LL M1 = MOD[1], M2 = MOD[0];
 84     LL t1 = power (M1, MOD[0] - 2, MOD[0]), t2 = power (M2, MOD[1] - 2, MOD[1]), invM = power (m % MOD[2], MOD[2] - 2, MOD[2]);
 85     for (int i = 0; i <= N + M; i ++) {
 86         LL a1 = ntta[0][i], a2 = ntta[1][i], a3 = ntta[2][i];
 87         LL A = (multi (a1 * M1 % m, t1 % m, m) + multi (a2 * M2 % m, t2 % m, m)) % m;
 88         LL k = ((a3 - A % MOD[2]) % MOD[2] + MOD[2]) % MOD[2] * invM % MOD[2];
 89         ans[i] = ((k % P * (m % P) % P + A % P) % P + P) % P;
 90     }
 91 }
 92 
 93 int getnum () {
 94     int num = 0;
 95     char ch = getchar ();
 96 
 97     while (! isdigit (ch))
 98         ch = getchar ();
 99     while (isdigit (ch))
100         num = (num << 3) + (num << 1) + ch - '0', ch = getchar ();
101 
102     return num;
103 }
104 
105 int main () {
106     N = getnum (), M = getnum (), P = (LL) getnum ();
107     for (int i = 0; i <= N; i ++)
108         A[i] = (LL) getnum ();
109     for (int i = 0; i <= M; i ++)
110         B[i] = (LL) getnum ();
111 
112     NTT_Main ();
113     CRT ();
114     for (int i = 0; i <= N + M; i ++) {
115         if (i)
116             putchar (' ');
117         printf ("%lld", ans[i]);
118     }
119     puts ("");
120 
121     return 0;
122 }
123 
124 /*
125 5 8 28
126 19 32 0 182 99 95
127 77 54 15 3 98 66 21 20 38
128 */
任意模数NTT(三模数NTT法)
原文地址:https://www.cnblogs.com/Colythme/p/9953329.html