hdu 6125 分组背包+状态压缩

题意:1---n 选择不超过k个数,使得他们的乘积不包含完全平方因子。

思路:考虑有相同因子的数不能同时选择,自然想到了分组,并且要加上一个维度K(来限制当前选了几个数字),但事实上一个数又可以属于很多组。。。

这里就可以使用状态压缩来压缩因子,{2,3,5,7,11,13,17,19},除去这些因子之后,就可以分组啦。

当然这些因子本身也是需要判断互不相交的,我们可以枚举状态i j 转移就是 (i&j==0)  dp[k][j|l]=(ls[k-1][l]+dp[k][j|l])%mod;

所以说,最后就是一个分组背包问题,组与组之间的合并(状态转移通过枚举状态暴力处理)。

PS。处理属于哪个组的时候要格外注意。。。要确保组内不能互相选择,组外的选择通过状态压缩解决。

代码:

#include<bits/stdc++.h>
using namespace std;
#define MEM(a,b) memset(a,b,sizeof(a))
#define bug puts("bug");
#define PB push_back
#define MP make_pair
#define X first
#define Y second
typedef unsigned long long ll;
typedef pair<int,int> pii;
const int maxn=4e5+10;
const int mod=1000000007;
using namespace std;
int t,m,n,k;
int p[8]={2,3,5,7,11,13,17,19};
int st[505],belong[505];
ll a[605][605],dp[605][605],ls[605][605];
vector<int> v[505];

ll sl(int N,int K){
    MEM(st,0);MEM(ls,0);MEM(a,0);
    ll ret=0;
    for(int i=1;i<=N;i++) belong[i]=i,v[i].clear();
    for(int i=1;i<=N;i++){
        for(int j=0;j<8;j++){
            if(st[i]!=-1&&i%p[j]==0&&i%(p[j]*p[j])!=0)
                st[i]|=(1<<j),belong[i]/=p[j];
            else if(i%p[j]==0&&i%(p[j]*p[j])==0){
                st[i]=-1;
                break;
            }
        }
    }
    for(int i=1;i<=N;i++)
        if(st[i]!=-1){
            if(belong[i]==1) v[i].PB(st[i]);
            else v[belong[i]].PB(st[i]);
        }
    for(int i=1;i<=N;i++){
        for(int id=0;id<v[i].size();id++)
            a[i][v[i][id]]++;
    for(int i=1;i<=N;i++){
        if(v[i].size()==0)continue;
        for(int k=1;k<=K;k++)
            for(int j=0;j<(1<<8);j++)
                dp[k][j]=ls[k][j];
        for(int j=0;j<(1<<8);j++){
            if(a[i][j]==0) continue;
            dp[1][j]=(dp[1][j]+1)%mod;
            for(int k=1;k<=K;k++){
                for(int l=0;l<(1<<8);l++)
                    if((j&l)==0)dp[k][j|l]=(ls[k-1][l]+dp[k][j|l])%mod;
        }
        for(int k=1;k<=K;k++){
            for(int j=0;j<(1<<8);j++){
                ls[k][j]=dp[k][j];
                dp[k][j]=0;
            }
        }
    }
    for(int k=1;k<=K;k++){
        for(int j=0;j<(1<<8);j++)
            ret=(ret+ls[k][j])%mod;
    return ret;
}


int main(){
    scanf("%d",&t);
    while(t--){
        scanf("%d%d",&n,&k);
        printf("%lld
",sl(n,k));
    }
    return 0;
}



原文地址:https://www.cnblogs.com/zhangxianlong/p/10672501.html