[SDOI 2013]方程

Description

题库链接

求不定方程 (x_1+x_2+cdots +x_n=m) 的正整数解的个数,并且要求满足限定: (forall iin[1,n_1] x_ileq a_i,forall iin[1,n_2] x_{n_1+i}geq a_{n_1+i}) 。对 (p) 取模, (t) 组询问。

(nleq 10^9,n_1leq 8,n_2leq 8,mleq 10^9, pleq 437367875,tleq 5)

Solution

如果没有约束,显然答案就是 (C_{m-1}^{n-1})

对于第二种约束,显然直接在总数中减去就好。

考虑如何处理第一种约束,其实直接容斥就好了,处理方法类似于第二种约束。

注意到模数不一定为质数,还需要扩展 (lucas)

Code

#include <bits/stdc++.h>
using namespace std;

int n, n1, n2, p, t, a[10], ans, m;
int fac[200005], pi[100005], pk[100005], tot;

int quick_pow(int a, int b, int p) {
    int ans = 1;
    while (b) {
    if (b&1) ans = 1ll*ans*a%p;
    b >>= 1, a = 1ll*a*a%p;
    }
    return ans;
}
void ex_gcd(int a, int b, int &x, int &y) {
    if (b == 0) {x = 1; y = 0; return; }
    ex_gcd(b, a%b, x, y);
    int t = x; x = y; y = t-a/b*y;
}
int inv(int a, int p) {
    int x, y; ex_gcd(a, p, x, y);
    return (x%p+p)%p;
}
int mul(int a, int pi, int pk) {
    if (a <= pi) return fac[a];
    int ans = fac[pk]; ans = quick_pow(ans, a/pk, pk);
    ans = 1ll*ans*fac[a%pk]%pk;
    return 1ll*ans*mul(a/pi, pi, pk)%pk;
}
int C(int n, int m, int pi, int pk) {
    int t = 0;
    for (int i = n; i; i /= pi) t += i/pi;
    for (int i = m; i; i /= pi) t -= i/pi;
    for (int i = n-m; i; i /= pi) t -= i/pi;
    if (quick_pow(pi, t, pk) == 0) return 0;
    fac[0] = 1; for (int i = 1; i <= pk; i++) if (i%pi) fac[i] = 1ll*i*fac[i-1]%pk; else fac[i] = fac[i-1];
    int a = mul(n, pi, pk), b = mul(m, pi, pk), c = mul(n-m, pi, pk);
    return 1ll*a*quick_pow(pi, t, pk)%pk*inv(b, pk)%pk*inv(c, pk)%pk;
}
int ex_lucas(int n, int m, int p) {
    int ans = 0;
    for (int i = 1; i <= tot; i++)
    (ans += 1ll*C(n, m, pi[i], pk[i])*(p/pk[i])%p*inv(p/pk[i], pk[i])%p) %= p;
    return ans;
}
void dfs(int c, int r, int f) {
    if (c == n1+1) {
    if (r < n) return;
    (ans += ex_lucas(r-1, n-1, p)*f) %= p; return;
    }
    dfs(c+1, r, f); dfs(c+1, r-a[c], -f);
}
void work() {
    scanf("%d%d", &t, &p);
    int T = p;
    for (int i = 2, x = sqrt(T); i <= x; i++)
    if (T%i == 0) {
        int tol = 1; while (T%i == 0) tol *= i, T /= i;
        pi[++tot] = i, pk[tot] = tol;
    }
    if (T != 1) pi[++tot] = pk[tot] = T;
    while (t--) {
    scanf("%d%d%d%d", &n, &n1, &n2, &m);
    for (int i = 1; i <= n1; i++) scanf("%d", &a[i]);
    for (int i = 1; i <= n2; i++) {scanf("%d", &T); if (T) m -= T-1; }
    ans = 0; dfs(1, m, 1); printf("%d
", (ans+p)%p);
    }
}
int main() {work(); return 0; }
原文地址:https://www.cnblogs.com/NaVi-Awson/p/8660505.html