AtCoder AGC034F RNG and XOR (概率期望、FWT)

题目链接

https://atcoder.jp/contests/agc034/tasks/agc034_f

题解

无论多水的题我都不会啊.jpg
首先考虑一个图上随机游走的经典问题,无向图求从(0)号点出发随机游走到每个点的期望时间。做法是显然答案等于从每个点走到(0)号点的期望时间,然后列方程高斯消元。
设答案向量为( extbf{x}), 则有(x_i=sum_{j ext{xor} k=i}p_kx_j+1(1le ilt 2^n))(x_0=0). 即( extbf{x})( extbf{p})异或卷积的结果为(egin{bmatrix}x_0' & x_1-1 & x_2-1 & ... & x_{2^n-1}-1end{bmatrix}). 观察到(sum^{2^n-1}_{i=0}p_i=1),故(sum^{2^n-1}_{i=0}x_i=x_0'+sum^{2^n-1}_{i=1}(x_i-1)), 即(x_0'=x_0+2^n-1). 再用(p_0-1)替换(p_0)( extbf{x})( extbf{p})异或卷积的结果为常数数列( extbf{a}=egin{bmatrix}2^n-1&-1&-1&...&-1end{bmatrix}).
现在已知( extbf{p})( extbf{a}),要求出( extbf{x}). FWT后的数列对应位置作除法即可。设( ext{FWT}( extbf{p})= extbf{P}) (其余字母同理), 则(forall 1le ile 2^n-1, P_ilt sum^{2^n-1}_{i=0}p_i=P_0=0), 也即( extbf{P})序列有且仅有(P_0)(0). (P_0)(A_0)皆为(0), 我们无法还原出(X_0).
( extbf{x}= ext{IFWT}( extbf{X})), 设( extbf{X'}=egin{bmatrix}0&X_1&X_2&...&X_{2^n-1}end{bmatrix}), 则(forall 0le ile 2^n-1, x'_i=x_i-frac{X_0}{2^n}=x_i-(x_0-x'_0)=x_i+x'_0), 故用(x_i=x'_i-x'_0)计算即可。
时间复杂度(O(2^nn)).

代码

#include<bits/stdc++.h>
#define llong long long
using namespace std;

inline int read()
{
	int x = 0,f = 1; char ch = getchar();
	for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}
	for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
	return x*f;
}

const int N = 18;
const int P = 998244353;
const llong INV2 = 499122177ll;
llong p[(1<<N)+3];
llong a[(1<<N)+3],b[(1<<N)+3];
int n,sum;

llong quickpow(llong x,llong y)
{
	llong cur = x,ret = 1ll;
	for(int i=0; y; i++)
	{
		if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
		cur = cur*cur%P;
	}
	return ret;
}
llong mulinv(llong x) {return quickpow(x,P-2);}

void fwt(int dgr,int coe,llong poly[],llong ret[])
{
	memcpy(ret,poly,sizeof(llong)*(1<<dgr));
	for(int i=0; i<dgr; i++)
	{
		for(int j=0; j<(1<<dgr); j+=(1<<i+1))
		{
			for(int k=0; k<(1<<i); k++)
			{
				llong x = poly[k+j],y = poly[k+(1<<i)+j];
				poly[k+j] = x+y>=P?x+y-P:x+y; poly[k+(1<<i)+j] = x-y<0?x-y+P:x-y;
			}
		}
	}
	if(coe==-1) {llong tmp = mulinv(1<<dgr); for(int i=0; i<(1<<dgr); i++) ret[i] = ret[i]*tmp%P;}
}

int main()
{
	scanf("%d",&n); for(int i=0; i<(1<<n); i++) {scanf("%lld",&p[i]); sum += p[i];} sum = mulinv(sum);
	for(int i=0; i<(1<<n); i++) p[i] = p[i]*sum%P; p[0] = (p[0]-1+P)%P;
	a[0] = (1<<n)-1; for(int i=1; i<(1<<n); i++) a[i] = P-1;
	fwt(n,1,a,a); fwt(n,1,p,p);
	b[0] = 0ll; for(int i=1; i<(1<<n); i++) b[i] = a[i]*mulinv(p[i])%P;
	fwt(n,-1,b,b);
	llong tmp = b[0]; for(int i=0; i<(1<<n); i++) b[i] = (b[i]-tmp+P)%P;
	for(int i=0; i<(1<<n); i++) printf("%lld
",b[i]);
	return 0;
}
原文地址:https://www.cnblogs.com/suncongbo/p/12250156.html