[Codeforces 1228E]Another Filling the Grid(组合数+容斥)

题目链接

解题思路:

对行和列同时做容斥,推导一番之后可以得到式子
\(\sum_{i=0}^{n}\sum_{j=0}^{n}(-1)^{i+j}C_n^iC_n^j(k-1)^{ni+nj-ij}k^{n^2-(ni+nj-ij)}\),直接计算的复杂度是 \(O(n^2\log n)\),但是还能继续化简:
\(\sum_{i=0}^{n}\sum_{j=0}^{n}(-1)^{i+j}C_n^iC_n^j(k-1)^{ni+nj-ij}k^{n^2-(ni+nj-ij)}\)
\(=\sum_{i=0}^{n}(-1)^iC_n^i\sum_{j=0}^{n}(-1)^jC_n^j(k-1)^{(n-i)j+ni}k^{(n-j)(n-i)}\)
\(=\sum_{i=0}^{n}(-1)^iC_n^i(k-1)^{ni}\sum_{j=0}^{n}(-1)^jC_n^j(k-1)^{(n-i)j}k^{(n-j)(n-i)}\)
由二项式定理有
\(=\sum_{i=0}^{n}(-1)^iC_n^i(k-1)^{ni}\left[k^{n-i}-(k-1)^{n-i}\right]^n\)
\(=\sum_{i=0}^{n}(-1)^iC_n^i\left[k^{n-i}(k-1)^i-(k-1)^n\right]^n\)
这样就能愉快地在 \(O(n\log n)\) 的时间内算出答案了

#include <bits/stdc++.h>
using namespace std;
/*    freopen("k.in", "r", stdin);
    freopen("k.out", "w", stdout); */
// clock_t c1 = clock();
// std::cerr << "Time:" << clock() - c1 <<"ms" << std::endl;
//#pragma comment(linker, "/STACK:1024000000,1024000000")
// Debug helper: prints "<expression text> = <value>" to cout, followed by endl.
#define de(a) cout << #a << " = " << a << endl
// Ascending int loop: i runs from a up to n, both ends inclusive.
#define rep(i, a, n) for (int i = a; i <= n; i++)
// Descending int loop: i runs from n down to a, both ends inclusive.
#define per(i, a, n) for (int i = n; i >= a; i--)
// Short aliases for commonly used integer and pair types.
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> PII;
typedef pair<double, double> PDD;
// NOTE(review): the second template argument of std::vector is the allocator
// type, so this alias cannot be instantiated. It is unused in the visible code;
// presumably vector<int> or vector<PII> was intended - confirm before using it.
typedef vector<int, int> VII;
// Large sentinel values for int (0x3f3f3f3f) and for ll (the same byte repeated 8 times).
#define inf 0x3f3f3f3f
const ll INF = 0x3f3f3f3f3f3f3f3f;
// Generic size limits; not referenced by the visible code below.
const ll MAXN = 4e3 + 7;
const ll MAXM = 1e6 + 7;
// Modulus used by quick_pow() and main() for all arithmetic.
const ll MOD = 1e9 + 7;
// Floating-point helpers; not referenced by the visible code below.
const double eps = 1e-6;
const double pi = acos(-1.0);
/**
 * Modular exponentiation by repeated squaring: computes a^b modulo MOD.
 *
 * @param a Base. Any value is accepted, including negative numbers and
 *          numbers outside [0, MOD); it is reduced into [0, MOD) first.
 * @param b Exponent, expected to be non-negative. For b <= 0 the loop body
 *          never runs and 1 is returned (so the call always terminates).
 * @return  a^b mod MOD as a non-negative residue.
 */
ll quick_pow(ll a, ll b)
{
    // Reduce the base up front: every later product is then below MOD^2,
    // which fits in a signed 64-bit integer, and a negative base can no
    // longer leak a negative residue into the result.
    a %= MOD;
    if (a < 0)
        a += MOD;

    ll ans = 1;
    // `b > 0` rather than `b`: an arithmetic right shift of a negative
    // exponent settles at -1 and would otherwise loop forever.
    while (b > 0)
    {
        if (b & 1)
            ans = (ans * a) % MOD; // multiply in the current bit's power of a
        a = (a * a) % MOD;         // a^(2^t) -> a^(2^(t+1))
        b >>= 1;
    }
    return ans;
}
// c[i][j] holds the binomial coefficient C(i, j) modulo MOD; main() fills it
// with Pascal's rule. As a global it starts zero-initialised.
int c[305][305];
// ksm1[i] holds k^i modulo MOD (filled by main()).
ll ksm1[305];
// ksm2[i] holds (k - 1)^i modulo MOD (filled by main()).
ll ksm2[305];
int main()
{
    ll n, k;
    scanf("%lld%lld", &n, &k);
    c[0][0] = 1;
    c[1][0] = c[1][1] = 1;
    for (int i = 2; i <= n; i++)
    {
        c[i][0] = 1;
        for (int j = 1; j <= i; j++)
            c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % MOD;
    }
    ksm1[0] = ksm2[0] = 1;
    for (int i = 1; i <= n; i++)
        ksm1[i] = (ksm1[i - 1] * k) % MOD, ksm2[i] = (ksm2[i - 1] * (k - 1)) % MOD;
    ll ans = 0;
    ll t = 1;
    for (int i = 0; i <= n; i++)
    {
        ans += t * c[n][i] * quick_pow((ksm1[n - i] * ksm2[i] - ksm2[n]) % MOD, n) % MOD;
        t *= -1;
        ans %= MOD;
    }
    printf("%lld
", (ans % MOD + MOD) % MOD);
    return 0;
}
原文地址:https://www.cnblogs.com/graytido/p/11885403.html