传送门
直接的想法就是设 (x^k) 为边权,矩阵树定理一波后取出 (x^{nk}) 的系数即可
也就是求出模 (x^k) 意义下的循环卷积的常数项
考虑插值出最后多项式,类比 (DFT) 的方法
假设我们要求
[C_i=sum_{j=0}^{n}sum_{k=0}^{n}A_jB_k[(j+k)~mod~n=i]
]
(A,B,C) 为多项式
我们知道了 (A,B) 的 (n) 个点值
[A(w_n^i)=sum_{k=0}^{n}A_kw_n^{ik}
]
[B(w_n^i)=sum_{k=0}^{n}B_kw_n^{ik}
]
那么
[C(w_n^k)=sum_{i=0}^{n}sum_{j=0}^{n}A_iw_n^{ik}B_jw_n^{jk}=sum_{i=0}^{n}sum_{j=0}^{n}A_iB_jw_n^{k(i+j)}
]
而根据消去引理 (w_n^{k(i+j)}=w_n^{k((i+j)~mod~n)})
所以
[C(w_n^k)=sum_{l=0}^{n}sum_{i=0}^{n}sum_{j=0}^{n}[(i+j)~mod~n=l]A_iB_jw_n^{kl}
]
正好对应了循环卷积,所以只要求得到 (w_n^{k},(k=0...n-1)) 的点值就可以得到最后的多项式了
这道题 (p~mod~k=1) 所以直接用原根就好了,最后插值一下
upd: 其实最后并不需要插值
根据单位根反演
[[k|x]=frac{1}{k}sum_{i=0}^{k-1}omega_{k}^{ix}
]
把多项式的每一项都换成这个东西,得到的值就是要的答案
也就是说直接带入每一个单位根,把矩阵树定理得到的权值加起来最后除去 (k) 就好了
# include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n, m, k, mod, g, pr[233333], tot, a[105][105];
int xi[105], yi[105];
struct Edge {
int u, v, w;
} edge[10005];
inline int Pow(ll x, int y) {
register ll ret = 1;
for (; y; y >>= 1, x = x * x % mod)
if (y & 1) ret = ret * x % mod;
return ret;
}
inline void Inc(int &x, int y) {
x = x + y >= mod ? x + y - mod : x + y;
}
inline void Getrt() {
register int x, i, j;
for (x = mod - 1, i = 2; i * i <= x; ++i)
if (x % i == 0) {
pr[++tot] = i;
while (x % i == 0) x /= i;
}
if (x > 1) pr[++tot] = x;
for (x = mod - 1, i = 2; i <= x; ++i) {
for (g = i, j = 1; g && j <= tot; ++j)
if (Pow(g, x / pr[j]) == 1) g = 0;
if (g) break;
}
}
inline int Gauss() {
register int ans = 1, i, j, l, inv;
for (i = 1; i < n; ++i) {
for (j = i; j < n; ++j)
if (a[j][i]) {
if (i != j) swap(a[i], a[j]), ans = mod - ans;
break;
}
for (j = i + 1; j < n; ++j)
if (a[j][i]) {
inv = (ll)a[j][i] * Pow(a[i][i], mod - 2) % mod;
for (l = i; l < n; ++l) Inc(a[j][l], mod - (ll)a[i][l] * inv % mod);
}
ans = (ll)ans * a[i][i] % mod;
}
return ans;
}
int main() {
register int i, j, w, u, v, ans;
scanf("%d%d%d%d", &n, &m, &k, &mod), Getrt();
for (i = 1; i <= m; ++i) scanf("%d%d%d", &edge[i].u, &edge[i].v, &edge[i].w);
xi[0] = 1, xi[1] = Pow(g, (mod - 1) / k);
for (i = 0; i < k; ++i) {
if (i > 1) xi[i] = (ll)xi[i - 1] * xi[1] % mod;
memset(a, 0, sizeof(a));
for (j = 1; j <= m; ++j) {
u = edge[j].u, v = edge[j].v, w = Pow(xi[i], edge[j].w);
Inc(a[u][u], w), Inc(a[v][v], w), Inc(a[u][v], mod - w), Inc(a[v][u], mod - w);
}
yi[i] = Gauss();
}
for (i = ans = 0; i < k; ++i) {
for (w = yi[i], j = 0; j < k; ++j)
if (i ^ j) w = (ll)w * (mod - xi[j]) % mod * Pow((xi[i] + mod - xi[j]) % mod, mod - 2) % mod;
Inc(ans, w);
}
printf("%d
", ans);
return 0;
}