[集训队互测]calc

题目

点这里看题目。

分析

首先不难想到可以枚举递增的序列,最后在答案里面乘上(n!),于是有(O(nk))的暴力 DP 一枚:

(f(i,j))表示长度为(i)、最大值(le j)的序列的贡献和。

转移显然:

[f(i,j)=j imes f(i-1,j-1)+f(i,j-1) ]

那么可以发现,当序列长度固定的时候,(f(n,x))肯定是关于(x)的函数。环顾四周,DP 转移方程中并不存在除法、开方、作为指数乘方等运算,所以可以推测(f(n,x))就是(x)的多项式函数。

那么,它的次数是多少呢?这直接决定了我们如何进行插值。设(f(n,x))的次数为(g(n)),考虑到转移左右两边的次数应该是相等的,就有:

[f(i,j)-f(i,j-1)=j imes f(i-1,j-1)Rightarrow g(n)-1=g(n-1)+1 ]

补充一下,多项式函数做差分,即(f(x)-f(x-1)),得到的结果的次数会比原多项式的小一,可以直接用二项式定理展开证明。

然后发现,(g(n)=g(n-1)+2),由于(f(0,x)=1),所以(g(0)=0),得到通项公式(g(n)=2n)

然后我们就知道了(f(n,x))是关于(x)(2n)次的多项式函数,因此,我们需要算出(2n+1)个点值,用于插值。总时间(O(n^2))

//f(i,j)=f(i-1,j-1)*j+f(i,j-1)
#include <cstdio>

const int MAXN = 1005;

template<typename _T>
void read( _T &x )
{
	x = 0;char s = getchar();int f = 1;
	while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
	while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
	x *= f;
}

template<typename _T>
void write( _T x )
{
	if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
	if( 9 < x ){ write( x / 10 ); }
	putchar( x % 10 + '0' );
}

int f[MAXN][MAXN], y[MAXN];
int N, K, M, mod;

int qkpow( int base, int indx )
{
	int ret = 1;
	while( indx )
	{
		if( indx & 1 ) ret = 1ll * ret * base % mod;
		base = 1ll * base * base % mod, indx >>= 1;
	}
	return ret;
}

int inver( const int a ) { return qkpow( a, mod - 2 ); }
void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }

int Lagrange()
{
	if( K <= M ) return y[K];
	int ans = 0, tmp;
	for( int i = 1 ; i <= M ; i ++ )
	{
		tmp = 1;
		for( int j = 1 ; j <= M ; j ++ )
			if( i != j )
				tmp = 1ll * tmp * ( K - j ) % mod * inver( i - j + mod ) % mod;
		add( ans, 1ll * tmp * y[i] % mod );
	}
	return ans;
}

int main()
{
	read( K ), read( N ), read( mod );
	M = 2 * N + 1;
	for( int j = 0 ; j <= M ; j ++ )
		f[0][j] = 1; 
	for( int i = 1 ; i <= N ; i ++ )
		for( int j = 1 ; j <= M ; j ++ )
			f[i][j] = ( 1ll * f[i - 1][j - 1] * j % mod + f[i][j - 1] ) % mod;
	for( int i = 0 ; i <= M ; i ++ ) y[i] = f[N][i];
	int fac = 1; 
	for( int i = 1 ; i <= N ; i ++ ) fac = 1ll * fac * i % mod;
	write( 1ll * fac * Lagrange() % mod ), putchar( '
' );
	return 0;
}
原文地址:https://www.cnblogs.com/crashed/p/13127451.html