题解
题目等价于求这个式子
\[ans=n\cdot 2^{\frac{(n-1)(n-2)}{2}}\sum\limits_{i=0}^{n-1}\binom{n-1}{i}i^k\]
有这么一个式子
\[i^k=\sum\limits_{j=0}^{i}\begin{Bmatrix}k\\j\end{Bmatrix}j!\binom{i}{j}\]
代入可得
\[ans=n\cdot 2^{\frac{(n-1)(n-2)}{2}}\sum\limits_{i=0}^{n-1}\binom{n-1}{i}\sum\limits_{j=0}^{i}\begin{Bmatrix}k\\j\end{Bmatrix}j!\binom{i}{j}\]
交换枚举顺序
\[ans=n\cdot 2^{\frac{(n-1)(n-2)}{2}}\sum\limits_{j=0}^{n-1}\begin{Bmatrix}k\\j\end{Bmatrix}j!\sum\limits_{i=j}^{n-1}\binom{n-1}{i}\binom{i}{j}\]
考虑到后面那个和号的组合意义为：先在 \(n-1\) 个数中确定 \(j\) 个，剩下的 \(n-1-j\) 个可选可不选，即
\[ans=n\cdot 2^{\frac{(n-1)(n-2)}{2}}\sum\limits_{j=0}^{n-1}\begin{Bmatrix}k\\j\end{Bmatrix}j!\binom{n-1}{j}2^{n-1-j}\]
\[=n\cdot 2^{\frac{(n-1)(n-2)}{2}}\sum\limits_{j=0}^{n-1}\begin{Bmatrix}k\\j\end{Bmatrix}\frac{(n-1)!}{(n-1-j)!}2^{n-1-j}\]
本题的 \(n\) 可能高达 \(10^9\)，但是发现当 \(j>k\) 时 \(\begin{Bmatrix}k\\j\end{Bmatrix}=0\)，于是改一下求和上界：
\[=n\cdot 2^{\frac{(n-1)(n-2)}{2}}\sum\limits_{j=0}^{\min(n-1,k)}\begin{Bmatrix}k\\j\end{Bmatrix}\frac{(n-1)!}{(n-1-j)!}2^{n-1-j}\]
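第二类斯特林数的第 \(k\) 行可以直接卷积出来：\(\begin{Bmatrix}k\\j\end{Bmatrix}=\sum\limits_{i=0}^{j}\frac{(-1)^{j-i}}{(j-i)!}\cdot\frac{i^k}{i!}\)，用 NTT 计算即可，总复杂度 \(O(k\log k+\log n)\)。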
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <string>
#include <vector>
#include <cmath>
#include <ctime>
#include <queue>
#include <map>
#include <set>
using namespace std;
// Contest-template shorthand macros. In the code shown in this file only N and
// MOD are referenced; the others are unused template leftovers.
// NOTE(review): `re` expands to `register`, which C++17 removed as a storage
// class specifier; this is harmless only while `re` stays unused.
#define ull unsigned long long
#define pii pair<int, int>
#define uint unsigned int
#define mii map<int, int>
#define lbd lower_bound
#define ubd upper_bound
#define INF 0x3f3f3f3f
#define IINF 0x3f3f3f3f3f3f3f3fLL
#define DEF 0x8f8f8f8f
#define DDEF 0x8f8f8f8f8f8f8f8fLL
#define vi vector<int>
#define ll long long
#define mp make_pair
#define pb push_back
#define re register
#define il inline
// Capacity of the global arrays below. main transforms arrays of length len,
// the smallest power of two >= 2k+2, so len must stay within N+5.
#define N 1000000
// NTT-friendly prime 998244353 = 119 * 2^23 + 1; DFT uses 3 as the primitive root.
#define MOD 998244353
// Input parameters, read in main.
int n, k;
// a, b: NTT operands ((-1)^i / i! and i^k / i!); S: their product after the
// inverse transform, i.e. the Stirling numbers of the second kind S(k, j);
// fac / facinv: factorials and inverse factorials up to k.
// As globals these are zero-initialised, so entries past k start out as 0.
int a[N+5], b[N+5], S[N+5], fac[N+5], facinv[N+5];
// Returns x^p modulo MOD by square-and-multiply.
// p must be >= 0: the loop only stops once p reaches 0.
// x is not normalised first, so a negative x may give a negative
// representative in (-MOD, MOD). fpow(x, 0) == 1 for every x (including 0).
int fpow(int x, int p) {
    int result = 1;
    for (int base = x; p != 0; p >>= 1) {
        if (p & 1) {
            result = 1LL * result * base % MOD;
        }
        base = 1LL * base * base % MOD;
    }
    return result;
}
void bitReverse(int *s, int bit, int len) {
static int tmp[4*N+5];
tmp[0] = 0;
for(int i = 1; i < len; ++i) {
tmp[i] = (tmp[i>>1]>>1)|((i&1)<<(bit-1));
if(i < tmp[i]) swap(s[i], s[tmp[i]]);
}
}
// In-place iterative number-theoretic transform of s[0..len) modulo MOD,
// with len == 1 << bit and 3 as the primitive root.
//   flag == 0 : forward transform.
//   flag != 0 : inverse transform (inverse root powers, then scale by 1/len).
// Butterflies keep plain C++ remainders, so entries may end up as negative
// representatives in (-MOD, MOD); callers normalise the final result.
void DFT(int *s, int bit, int len, int flag) {
    bitReverse(s, bit, len);
    for (int span = 2; span <= len; span <<= 1) {
        const int half = span >> 1;
        int root = fpow(3, (MOD - 1) / span);
        if (flag) {
            root = fpow(root, MOD - 2);
        }
        for (int *blk = s; blk != s + len; blk += span) {
            int *lo = blk;
            int *hi = blk + half;
            int w = 1;
            for (int i = 0; i < half; ++i) {
                const int u = lo[i];
                const int v = 1LL * w * hi[i] % MOD;
                lo[i] = (u + v) % MOD;
                hi[i] = (u - v) % MOD;
                w = 1LL * w * root % MOD;
            }
        }
    }
    if (!flag) {
        return;
    }
    const int scale = fpow(len, MOD - 2);
    for (int i = 0; i < len; ++i) {
        s[i] = 1LL * s[i] * scale % MOD;
    }
}
int main() {
scanf("%d%d", &n, &k);
if(n == 1) {
printf("0
");
return 0;
}
fac[0] = 1;
for(int i = 1; i <= k; ++i) fac[i] = 1LL*fac[i-1]*i%MOD;
facinv[k] = fpow(fac[k], MOD-2);
for(int i = k; i >= 1; --i) facinv[i-1] = 1LL*facinv[i]*i%MOD;
for(int i = 0; i <= k; ++i) {
a[i] = facinv[i];
if(i&1) a[i] *= -1;
b[i] = 1LL*fpow(i, k)*facinv[i]%MOD;
}
int bit = 0, len;
while((1<<bit) < 2*k+2) bit++;
len = (1<<bit);
DFT(a, bit, len, 0), DFT(b, bit, len, 0);
for(int i = 0; i < len; ++i) S[i] = 1LL*a[i]*b[i]%MOD;
DFT(S, bit, len, 1);
int ans = 0, lim = min(n-1, k), x = 1, y = fpow(2, n-1), t = fpow(2, MOD-2);
for(int i = 0; i <= lim; ++i) {
ans = (ans+1LL*S[i]*x%MOD*y%MOD)%MOD;
x = 1LL*x*(n-1-i)%MOD, y = 1LL*y*t%MOD;
}
if(n&1) ans = 1LL*n*fpow(fpow(2, (n-1)/2), n-2)%MOD*ans%MOD;
else ans = 1LL*n*fpow(fpow(2, (n-2)/2), n-1)%MOD*ans%MOD;
ans = (ans+MOD)%MOD;
printf("%d
", ans);
return 0;
}