BZOJ2004: [Hnoi2010]Bus 公交线路

BZOJ2004: [Hnoi2010]Bus 公交线路

Description

小Z所在的城市有N个公交车站,排列在一条长(N-1)km的直线上,从左到右依次编号为1到N,相邻公交车站间的距离均为1km。
作为公交车线路的规划者,小Z调查了市民的需求,决定按下述规则设计线路:
1.设共K辆公交车,则1到K号站作为始发站,N-K+1到N号台作为终点站。
2.每个车站必须被一辆且仅一辆公交车经过(始发站和终点站也算被经过)。 
3.公交车只能从编号较小的站台驶往编号较大的站台。 
4.一辆公交车经过的相邻两个站台间距离不得超过Pkm。
在最终设计线路之前,小Z想知道有多少种满足要求的方案。
由于答案可能很大,你只需求出答案对30031取模的结果。

Input

仅一行包含三个正整数N K P,分别表示公交车站数,公交车数,相邻站台的距离限制。
N<=10^9,1<P<=10,K<N,1<K<=P

Output

仅包含一个整数,表示满足要求的方案数对30031取模的结果。

Sample Input

样例一:10 3 3
样例二:5 2 3
样例三:10 2 4

Sample Output

1
3
81

HINT

【样例说明】

样例一的可行方案如下: (1,4,7,10),(2,5,8),(3,6,9)

样例二的可行方案如下: (1,3,5),(2,4) (1,3,4),(2,5) (1,4),(2,3,5) 

P<=10 , K <=8
题解Here!

看到P<=10,立马想到状压DP。

然后本蒟蒻就不会了,还是太菜了。。。

注意看这样一种路线:

A B C _ _->_ B C A _->_ B _ A C

A B C _ _->A B _ _ C->_ B _ A C

虽然顺序不同,但是他们是同一种方案,都是A由1到4,C由3到5。

所以我们不妨强制要求必须得最靠前的先走。

这样一来就可以转移了。

一个P位的二进制位,恰好有k个1且最高位为1表示状态。

我刚开始不明白为什么恰好有k个,也不明白为什么最高位为1,想到那个强制要求之后就懂了。

所以合法状态最多有C94=126。

但是那个N<=109怎么解?

时间复杂度最高也只有log2N×C94

注:矩阵乘法真是个玄学的东东。。。

附代码:

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#define MAXN 210
#define MOD 30031
using namespace std;
int n,m=0,k,p;
int bit[20],val[MAXN];
struct Matrix{
    long long a[MAXN][MAXN];
}base,ans;
inline int read(){
    int date=0,w=1;char c=0;
    while(c<'0'||c>'9'){if(c=='-')w=-1;c=getchar();}
    while(c>='0'&&c<='9'){date=date*10+c-'0';c=getchar();}
    return date*w;
}
Matrix operator *(Matrix x,Matrix y){
    Matrix ret;
    for(int i=1;i<=m;i++)
    for(int j=1;j<=m;j++){
        ret.a[i][j]=0;
        for(int k=1;k<=m;k++){
        	ret.a[i][j]+=x.a[i][k]*y.a[k][j]%MOD;;
        	ret.a[i][j]%=MOD;
        }
    }
    return ret;
}
Matrix mexp(int b){
    Matrix s;
    for(int i=1;i<=m;i++)s.a[i][i]=1;
    while(b){
        if(b&1)s=s*base;
        base=base*base;
        b>>=1;
    }
    return s;
}
inline int lowbit(int x){return x&(-x);}
void dfs(int x,int s,int v){
    if(s==k){
        val[++m]=v;
        return;
    }
    for(int i=x-1;i;i--)dfs(i,s+1,v+bit[i-1]);
}
void work(){
    ans.a[1][1]=1;
    base=mexp(n-k);
    ans=ans*base;
    printf("%lld
",ans.a[1][1]);
}
void init(){
    n=read();k=read();p=read();
    bit[0]=1;
    for(int i=1;i<=19;i++)bit[i]=bit[i-1]<<1;
    dfs(p,1,bit[p-1]);
    for(int i=1;i<=m;i++)
    for(int j=1;j<=m;j++){
        int x=(val[i]<<1)^bit[p]^val[j];
        if(x==lowbit(x))base.a[i][j]=1;
    }
}
int main(){
    init();
    work();
    return 0;
}
原文地址:https://www.cnblogs.com/Yangrui-Blog/p/9417294.html