HDU5845 trie树优化dp

http://acm.hdu.edu.cn/showproblem.php?pid=5845

题意:给定序列,问最多可以分成多少段序列使得每段序列不超过L且异或和不超过X

首先对于区间异或和,很容易想到前缀异或和去优化使其可以在O(1)时间内求出区间异或和,然后我们就可以写出一个n²暴力

#include <map>
#include <set>
#include <ctime>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <string>
#include <bitset>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std;
#define For(i, x, y) for(int i=x;i<=y;i++)  
#define _For(i, x, y) for(int i=x;i>=y;i--)
#define Mem(f, x) memset(f,x,sizeof(f))  
#define Sca(x) scanf("%d", &x)
#define Sca2(x,y) scanf("%d%d",&x,&y)
#define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z)
#define Scl(x) scanf("%lld",&x)  
#define Pri(x) printf("%d
", x)
#define Prl(x) printf("%lld
",x)  
#define CLR(u) for(int i=0;i<=N;i++)u[i].clear();
#define LL long long
#define ULL unsigned long long  
#define mp make_pair
#define PII pair<int,int>
#define PIL pair<int,long long>
#define PLL pair<long long,long long>
#define pb push_back
#define fi first
#define se second 
typedef vector<int> VI;
int read(){int x = 0,f = 1;char c = getchar();while (c<'0' || c>'9'){if (c == '-') f = -1;c = getchar();}
while (c >= '0'&&c <= '9'){x = x * 10 + c - '0';c = getchar();}return x*f;}
const double PI = acos(-1.0);
const double eps = 1e-9;
const int maxn = 1e5 + 10;
const int INF = 0x3f3f3f3f;
const int mod = 268435456; 
LL N,X,L,P,Q;
LL a[maxn],dp[maxn];
LL pre[maxn];
LL sum(int i,int j){
    return pre[j] ^ pre[i - 1];
}
int main(){
    int T; Sca(T);
    while(T--){
        scanf("%lld%lld%lld",&N,&X,&L);
        scanf("%lld%lld%lld",&a[1],&P,&Q);
        for(int i = 2; i <= N ; i ++){
            a[i] = ((a[i - 1] * P) + Q) % mod;
        }
        pre[0] = 0;
        for(int i = 1; i <= N ; i ++) pre[i] = pre[i - 1] ^ a[i];
        for(int i = 0; i <= N ; i ++) dp[i] = 0;
        for(int i = 1; i <= N; i ++){
            for(int j = max(0LL,i - L); j < i ; j ++){
                if(sum(j + 1,i) <= X) dp[i] = max(dp[i],dp[j] + 1);
            }
        }
        Prl(dp[N]);
    }
    return 0;
}
n²暴力

我们可以发现对于pre相同的下标而言,dp的大小呈单调性,即i > j 且pre[i] = pre[j] 则dp[i] > dp[j],由于i,j之间异或和为0,显然dp[i] - dp[j] >= 1

那么对于前面长度L的区间,我们可以考虑用字典树优化,用01字典树维护每个前缀和的dp最大值,由于满足单调性,对于字典树上的删除我们只需要维护每个节点出现的次数,因为只要字典树上还存在当前节点(出现次数不为0),就意味着当前最大值不会变(最大值永远越后面的越大)

对于查询的时候就需要讨论,如果当前位X为0,说明查询的pre当前位上也是0,需要走当前位与他相同的路径,如果X为1,那么可以走与当前位相反的路径使得该位和X一样为1,或者走与其相同的路径使得该位为0,倘若走0的路径,那么直接取子树的最大值不用继续往下走,因为下面无论怎么走都一定比X小

#include <map>
#include <set>
#include <ctime>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <string>
#include <bitset>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std;
#define For(i, x, y) for(int i=x;i<=y;i++)  
#define _For(i, x, y) for(int i=x;i>=y;i--)
#define Mem(f, x) memset(f,x,sizeof(f))  
#define Sca(x) scanf("%d", &x)
#define Sca2(x,y) scanf("%d%d",&x,&y)
#define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z)
#define Scl(x) scanf("%lld",&x)  
#define Pri(x) printf("%d
", x)
#define Prl(x) printf("%lld
",x)  
#define CLR(u) for(int i=0;i<=N;i++)u[i].clear();
#define LL long long
#define ULL unsigned long long  
#define mp make_pair
#define PII pair<int,int>
#define PIL pair<int,long long>
#define PLL pair<long long,long long>
#define pb push_back
#define fi first
#define se second 
typedef vector<int> VI;
int read(){int x = 0,f = 1;char c = getchar();while (c<'0' || c>'9'){if (c == '-') f = -1;c = getchar();}
while (c >= '0'&&c <= '9'){x = x * 10 + c - '0';c = getchar();}return x*f;}
const int maxn = 1e5 + 10;
const int maxm = 5e6 + 10;
const LL INF = 1e18;
const LL mod = 268435456; 
LL N,X,L,P,Q;
LL a[maxn],dp[maxn],pre[maxn];
int nxt[maxm][2],cnt,num[maxm];
LL val[maxm];
void insert(int j){
    LL x = pre[j],v = dp[j];
    int p = 1;
    for(int i = 32; i >= 0; i --){
        int id = (x >> i) & 1;
        if(!nxt[p][id]){
             nxt[p][id] = ++cnt;
             val[cnt] = -INF; num[cnt] = nxt[cnt][0] = nxt[cnt][1] = 0;
        }
        p = nxt[p][id];
        val[p] = max(val[p],v); num[p]++;
    }
}
void del(int p,int i,LL x){
    if(i == -1){if(!num[p]) val[p] = -INF;return;}
    int id = (x >> i) & 1;
    num[nxt[p][id]]--;
    del(nxt[p][id],i - 1,x);
    val[p] = val[nxt[p][id]];
    if(nxt[p][id ^ 1] && num[nxt[p][id ^ 1]] > 0) val[p] = max(val[nxt[p][0]],val[nxt[p][1]]);
}
LL query(LL x){
    int p = 1;
    LL ans = -INF;
    for(int i = 32; i >= 0 ; i --){
        int id = (x >> i) & 1;
        if((X >> i) & 1){
            if(nxt[p][id] && num[nxt[p][id]]){
                ans = max(ans,val[nxt[p][id]]);
            }
            if(nxt[p][id ^ 1] && num[nxt[p][id ^ 1]]){
                p = nxt[p][id ^ 1];
            }
        }else{
            if(!nxt[p][id] || !num[nxt[p][id]]) return ans;
            p = nxt[p][id];
        }
    }
    ans = max(ans,val[p]);
    return ans;
    return val[p];
}
int main(){
    int T; Sca(T); cnt = 1;
    while(T--){
        for(int i = 0 ; i <= cnt; i ++){val[i] = -INF; nxt[i][0] = nxt[i][1] = num[i] = 0;}
        scanf("%lld%lld%lld",&N,&X,&L); cnt = 1;
        scanf("%lld%lld%lld",&a[1],&P,&Q);
        for(int i = 2; i <= N ; i ++) a[i] = ((a[i - 1] * P) + Q) % mod;
        pre[0] = dp[0] = 0; insert(0);
        for(int i = 1; i <= N ; i ++) pre[i] = pre[i - 1] ^ a[i];
    //    For(i,1,N) cout << pre[i] << " ";
    //    cout << endl;
        for(int i = 1; i <= N; i ++){
            if(i - L - 1 >= 0 && dp[i - L - 1] >= 0) del(1,32,pre[i - L - 1]);
            dp[i] = query(pre[i]) + 1;
            if(dp[i] > 0) insert(i);
        }
        if(dp[N] < 0) dp[N] = 0;
        Prl(dp[N]);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/Hugh-Locke/p/11280544.html