[CF453D]Little Pony and Elements of Harmony

题目

  点这里看题目。

分析

  设(count(x))(x)的二进制中(1)的个数。因此(f(u,v)=count(uoplus v))
  看一看每次转移,我们发现最不友好的东西就是(f(u,v)),因此我们得想办法把它从我们的计算中丢掉。
  发现对于([0,n))中的所有数,两两异或之后不会超过(n)。并且对于一个固定的数(x),其(count(x))是不会变的。因此我们考虑将(b)数组转存出来:

[a[i]=b[count(i)] ]

  因此有:

[e[u]=sum_v a[uoplus v]e[v] ]

  考虑改变枚举顺序:

[egin{aligned} e[u] &=sum_v a[uoplus v]e[v]\ &=sum_{i=0}^ma[i]sum_{uoplus v=i}e[v]\ &=sum_{i=0}^m a[i]sum_{voplus i=u}e[v]\ &=sum_{voplus i=u}a[i]e[v]end{aligned} ]

  因此每次转移都是一个异或卷积的形式,可以用 FWT 优化一发。由于需要转移(t)次,还可用快速幂。 FWT 只需要在初始和最后做,中途快速幂不需要。时间是(O(mn + nlog_2t))
  这里还有一问题。由于本题给的是任意模数,可能不存在(2)逆元。
  众所周知, 异或 FWT 还有一种版本,也就是像 FFT 一样,正变换和逆变换大部分一样,但是逆变换会在最后除掉向量长度(事实上 FWT 和 FFT 有很多相似处,可以在 K 进制 FWT 中了解到)
  因此我们可以使用上述的 FWT 。但是这里还有问题,(p)可能也没有(n)的逆元。根据同余基本性质:

[aequiv bpmod pLeftrightarrow frac a dequiv frac b dpmod {frac p d}(d|gcd{a,b,m}) ]

  我们把(p)扩大(n)倍之后就可以直接除得正确答案了。
  最后一个问题,(n imes p)(10^{15})的,如果直接乘法会溢出(什么? __int128 ?)。因此我们需要用 long double 来模拟取模(龟速乘太慢了)。

代码

#include <cstdio>

typedef long long LL;
typedef long double LB;

const int MAXM = 25, MAXN = 1.5e6 + 5;

template<typename _T>
void read( _T &x )
{
	x = 0;char s = getchar();int f = 1;
	while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
	while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
	x *= f;
}

template<typename _T>
void write( _T x )
{
	if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
	if( 9 < x ){ write( x / 10 ); }
	putchar( x % 10 + '0' );
}

LL E[MAXN], C[MAXN];
int B[MAXM];
LL T, mod;
int N, M;

int lowbit( const int &x ) { return x & ( -x ); }
LL fix( const LL a ) { return ( a % mod + mod ) % mod; }
int count( int x ) { int ret = 0; while( x ) ret ++, x -= lowbit( x ); return ret; }

LL mul( const LL a, const LL b ) { return fix( a * b - ( LL ) ( ( LB ) a / mod * b ) * mod ); }

void FWT( LL *f, const int mode )
{
	LL t1, t2;
	for( int s = 2 ; s <= N ; s <<= 1 )
		for( int i = 0, t = s >> 1 ; i < N ; i += s )
			for( int j = i ; j < i + t ; j ++ )
			{
				t1 = f[j], t2 = f[j + t];
				f[j] = ( t1 + t2 ) % mod, f[j + t] = fix( t1 - t2 );
			}
	if( ~ mode ) return ;
	for( int i = 0 ; i < N ; i ++ ) f[i] /= N;
}

void mul( LL *ret, LL *mult )
{
	for( int i = 0 ; i < N ; i ++ ) 
		ret[i] = mul( ret[i], mult[i] );
}

int main()
{
	read( M ), read( T ), read( mod );
	N = 1 << M, mod *= N;
	for( int i = 0 ; i < N ; i ++ ) read( E[i] );
	for( int i = 0 ; i <= M ; i ++ ) read( B[i] );
	for( int i = 0 ; i < N ; i ++ ) C[i] = B[count( i )];
	FWT( E, 1 ), FWT( C, 1 );
	while( T )
	{
		if( T & 1 ) mul( E, C );
		mul( C, C ), T >>= 1;
	}
	FWT( E, -1 );
	for( int i = 0 ; i < N ; i ++ ) write( E[i] ), puts( "" );
	return 0;
}
原文地址:https://www.cnblogs.com/crashed/p/12583953.html