扩展KMP --- HDU 3613 Best Reward

 Best Reward

Problem's Link:   http://acm.hdu.edu.cn/showproblem.php?pid=3613


Mean: 

给你一个字符串,每个字符都有一个权值(可能为负),你需要将这个字符串分成两个子串,使得这两个子串的价值之和最大。一个子串价值的计算方法:如果这个子串是回文串,那么价值就是这个子串所有字符权值之和;否则价值为0。

analyse:

扩展KMP算法运用。
总体思路:
找出所有包含第一个字母的回文串和包含最后一个字母的回文串,然后O(n)扫一遍,每次判断第i个字母之前(包含第i个字母)的子串是否是回文,以及从第i个字母后的子串是否是回文,然后计算出答案,取最大值。
具体做法:
假设输入的字符串是"abcda"
构造串s1="abcda#adcba"
求s1的Next数组,得到了包含第一个字母的回文串的位置;
构造串s2="adcba#abcda"
求s2的Next数组,得到了包含最后一个字母的回文串的位置;
用两个flag数组标记这些位置,然后扫一遍就得答案了。
中间加一个'#'并后接反串的目的是:当整个串都是回文的时候能够被Next数组记录下。

Time complexity: O(nlogn)

Source code: 

 第一遍写,不够优化:

/*
* this code is made by crazyacking
* Time: 0MS
* Memory: 137KB
*/
#include <queue>
#include <cstdio>
#include <string>
#include <stack>
#include <cmath>
#include <set>
#include <map>
#include <cstdlib>
#include <climits>
#include <vector>
#include <iostream>
#include <algorithm>
#include <cstring>
#define  MAXN 500010*2
#define  LL long long
using namespace std;
int len;
int Next[MAXN],ne[MAXN];
int sum[MAXN];
vector<int> val;
bool flag1[MAXN],flag2[MAXN];
char s[MAXN],s1[MAXN],s2[MAXN],sr[MAXN];
void get_sum()
{
        sum[0]=val[s[0]-'a'];
        for(int i=1;i<len;++i)
                sum[i]=sum[i-1]+val[s[i]-'a'];
}
void get_s1()
{
        strcpy(s1,s);
        s1[len]='#';
        s1[len+1]='';
        strcat(s1,sr);
}
void get_s2()
{
        strcpy(s2,sr);
        s2[len]='#';
        s2[len+1]='';
        strcat(s2,s);
}

void get_Next(char s[])
{
        Next[0]=0;
        int s_len=strlen(s);
        for(int i=1,k=0;i<s_len;++i)
        {
                while(k!=0 && s[i]!=s[k])
                        k=Next[k-1];
                if(s[i]==s[k]) k++;
                Next[i]=k;
        }
}
int main()
{
        ios_base::sync_with_stdio(false);
        cin.tie(0);
        int Cas;
        cin>>Cas;
        while(Cas--)
        {
                val.clear();
                int cnt=26,t;
                while(cnt--)
                {
                        cin>>t,val.push_back(t);
                }
                scanf("%s",s);
                len=strlen(s);
                if(strlen(s)==1)
                {
                        printf("%d
",val[s[0]-'a']);continue;
                }
                get_sum();
                strcpy(sr,s);
                strrev(sr);
                get_s1();
                get_s2();
                memset(flag1,0,sizeof flag1);
                memset(flag2,0,sizeof flag2);
                get_Next(s1);
                int k=Next[2*len];
                while(k!=0)
                {
                        flag1[k-1]=1;
                        k=Next[k-1];
                }
                get_Next(s2);
                k=Next[2*len];
                while(k!=0)
                {
                        flag2[k-1]=1;
                        k=Next[k-1];
                }
                reverse(flag2,flag2+len);
                long long ans=INT_MIN;
                long long tmp=0;
                for(int i=0;i<len-1;++i)
                {
                        tmp=0;
                        if(flag1[i])
                        {
                                tmp+=sum[i];
                        }
                        if(flag2[i+1])
                        {
                                tmp=tmp+(sum[len-1]-sum[i]);
                        }
                        if(tmp>ans)
                                ans=tmp;

                }
                cout<<ans<<endl;

        }
        return 0;
}
/*

*/
View Code

优化后的代码:

/*
* this code is made by crazyacking
* Verdict: Accepted
* Submission Date: 2015-05-07-16.26
* Time: 0MS
* Memory: 137KB
*/
#include <queue>
#include <cstdio>
#include <set>
#include <string>
#include <stack>
#include <cmath>
#include <climits>
#include <map>
#include <cstdlib>
#include <iostream>
#include <vector>
#include <algorithm>
#include <cstring>
#define  LL long long
#define  ULL unsigned long long
using namespace std;
const int MAXN=500010;
int val[30],Next[MAXN*2],sum[MAXN];
char s[MAXN],s1[MAXN*2];
bool flag[2][MAXN];
void get_sum()
{
        int len=strlen(s);
        sum[0]=val[s[0]-'a'];
        for(int i=1;i<len;++i)
                sum[i]=sum[i-1]+val[s[i]-'a'];
}

void get_Next(char ss[])
{
        int len=strlen(ss);
        Next[0]=0;
        int k=0;
        for(int i=1;i<len;++i)
        {
                while(k!=0 && ss[i]!=ss[k])
                        k=Next[k-1];
                if(ss[i]==ss[k]) k++;
                Next[i]=k;
        }
}

void get_flag(int x)
{
        strcpy(s1,s);
        int len=strlen(s);
        s1[len]='#';
        strrev(s);
        strcat(s1+len+1,s);
        get_Next(s1);
        len=strlen(s1);
        int k=Next[len-1];
        while(k!=0)
        {
                flag[x][k-1]=1;
                k=Next[k-1];
        }
        memset(s1,0,sizeof s1);
}
int main()
{
        ios_base::sync_with_stdio(false);
        cin.tie(0);
        int Cas;
        scanf("%d",&Cas);
        while(Cas--)
        {
                for(int i=0;i<26;++i)
                        scanf("%d",&val[i]);
                scanf("%s",s);
                if(strlen(s)==1)
                {
                        printf("%d
",val[s[0]-'a']);continue;
                }
                get_sum();
                memset(flag,0,sizeof flag);
                get_flag(0);
                get_flag(1);
                int len=strlen(s);
                reverse(flag[1],flag[1]+len);
                long long ans=LLONG_MIN,tmp;
                for(int i=0;i<len-1;++i)
                {
                        tmp=0;
                        tmp=(flag[0][i]?sum[i]:0)+(flag[1][i+1]?sum[len-1]-sum[i]:0);
                        ans=ans>tmp?ans:tmp;
                }
                printf("%lld
",ans);
        }
        return 0;
}
/*

*/
View Code
原文地址:https://www.cnblogs.com/crazyacking/p/4483476.html