[bzoj4361]isn【dp】【容斥】

【题目描述】

Description

给出一个长度为n的序列A(A1,A2...AN)。如果序列A不是非降的,你必须从中删去一个数,
这一操作,直到A非降为止。求有多少种不同的操作方案,答案模10^9+7。

Input

第一行一个整数n。
接下来一行n个整数,描述A。

Output

一行一个整数,描述答案。

Sample Input

4
1 7 5 3

Sample Output

18

HINT

1<=N<=2000

Source

【题解】

        一道有些麻烦的dp。

        记 f[i] 表示长度为 i 的非降序列的方案数。

        可以通过树状数组优化到O(n^2 log n)。// 从小往大做

        如果不考虑重复 ans=∑f[i]*(n-i)!

        但会有重复的情况,因为可能在长度为 i+1 时序列已经是非降的。

        一个长度为 i 的非降序列,会产生 i 个重复的长度为 i-1 的非降序列

        因此 ans=ans - ∑f[i]*(n-i)!*i (i>1)。

 
/* --------------
    user Vanisher
    problem bzoj-4361 
----------------*/
# include <bits/stdc++.h>
# define    N       2010
# define    inf     1e9
# define    ll      long long
# define    P       1000000007
using namespace std;
struct node{
    int a,p;
}p[N];
int h[N][N];
long long f[N],n,mul[N],g[N];
bool cmpa(node a, node b){
    return a.a<b.a;
}
bool cmpp(node a, node b){
    return a.p<b.p;
}
int lowbit(int x){
    return x&(-x);
}
ll mypow(ll x, int y){
    ll a=x; x=1;
    while (y>0){
        if (y%2==1) x=x*a%P;
        a=a*a%P;
        y/=2;
    }
    return x;
}
int query(int i, int x){
    int num=0;
    while (x>0){
        num=(num+h[i][x])%P;
        x=x-lowbit(x);
    }
    return num;
}
void modify(int i, int x, int k){
    while (x<=n){
        h[i][x]=(h[i][x]+k)%P;
        x=x+lowbit(x);
    }
}
int main(){
    scanf("%d",&n);
    for (ll i=1; i<=n; i++){
        scanf("%d",&p[i].a);
        p[i].p=i;
    }
    mul[0]=1;
    for (ll i=1; i<=n; i++) mul[i]=mul[i-1]*i%P;
    sort(p+1,p+n+1,cmpa);
    ll la=-inf,num=0;
    for (int i=1; i<=n; i++)
        if (p[i].a!=la){
            la=p[i].a;
            p[i].a=++num;
        }
        else p[i].a=num;
    for (int i=1; i<=n; i++){
        g[i]=1;
        for (int j=1; j<i; j++)
            g[i]=max(g[i],g[j]+1);
    }
    sort(p+1,p+n+1,cmpp);
    modify(0,1,1);
    for (int i=1; i<=n; i++)
        for (int j=g[i]; j>=1; j--){
            int now=query(j-1,p[i].a);
            modify(j,p[i].a,now);
        }
    for (int i=1; i<=n; i++)
        f[i]=query(i,n);
    ll ans=0;
    for (ll i=1; i<=n; i++)
        ans=(ans+f[i]*mul[n-i])%P;
    for (ll i=2; i<=n; i++)
        ans=(ans-f[i]*mul[n-i]%P*i)%P;
    printf("%lld
",(ans+P)%P);
    return 0;
}


原文地址:https://www.cnblogs.com/Vanisher/p/9136018.html