hdu 2243(Trie图(AC自动机)+DP+矩阵乘法+恶心取模)

题意:容易理解...

思路:就是poj 2778的加强版吧,实现的思想差不多,在做这道题之前建议先做poj 1625、poj 2778,这三道题是同一种题型吧!出现过给定的单词的单词数=总单词数-未出现过给定单词的单词数,前者等于26^1+26^2+26^3....26^l,后者等于A^1+A^2+..A^L,这里的A是根据能走的字符之间的路径数建立出来的邻接矩阵。两者求出来一减就可以了,对2^64取模可以忽略,直接用unsigned long long做,忽视溢出就等于取模了(这个地方很坑爹啊),注意减法最后结果可能小于0,要加2^64变成正数。求a^1+..a^n这是矩阵乘法中关于等比矩阵的求法:|A  E|

                                  |0  E|

其中的A为m阶矩阵,就是通过Trie图求得的,E是单位矩阵,0是零矩阵。而我们要求的是:                                                                              

A^1+A^2+..A^L,由等比矩阵的性质

|A  ,  1|                 |A^n , 1+A^1+A^2+....+A^(n-1)| 

|0  ,  1| 的n次方等于|0     ,                                       1  | 

所以我们只需要将A矩阵扩大四倍,变成如上形式的矩阵B,然后开L+1次方就可以得到1+A^1+A^2+....+A^L。由于多了一个1,所以最后得到的答案我们还要减去1。

代码实现:

#include<iostream>
#include<string.h>
#include<queue>
#include<algorithm>
using namespace std;
#define ll unsigned __int64
struct node{
    int flag;
    int fail;
    int next[26];
    void init()
    {
        flag=0;
        fail=0;
        memset(next,0,sizeof(next));
    }
}s[32];
struct yun{
    ll num[32*2][32*2];
    int i,j;
    void yunsi()
    {
       for(i=0;i<64;i++)
           for(j=0;j<64;j++)
               num[i][j]=0;
    }
};
ll mod=10330176681277348905;
struct yun a;
int n,m,tot;
void ca()
{
    tot=0;
    s[0].init();
}
void insert(char *str)//字典树
{
    int index,p=0;
    for(;*str!='\0';str++)
    {
        index=*str-'a';
        if(s[p].next[index]==0)
        {
            s[++tot].init();
            s[p].next[index]=tot;
        }
        p=s[p].next[index];
    }
    s[p].flag=1;
}
void AC_tree()//Trie图
{
    queue<int>Q;
    int p,son,cur,temp,i;
    s[0].fail=0;
    Q.push(0);
    while(!Q.empty())
    {
        p=Q.front();
        Q.pop();
        if(s[p].flag==1)
            continue;
        for(i=0;i<26;i++)
        {
            son=s[p].next[i];
            if(son!=0)
            {
                if(p==0)
                    s[son].fail=0;
                else
                {
                    cur=s[p].fail;
                    while(cur!=0&&s[cur].next[i]==0)
                        cur=s[cur].fail;
                    s[son].fail=s[cur].next[i];
                }
                if(s[s[son].fail].flag)
                    s[son].flag=1;
                if(s[son].flag==0)
                    a.num[p][son]++;
                Q.push(son);
            }
            else
            {
                temp=s[s[p].fail].next[i];
                s[p].next[i]=temp;
                if(s[temp].flag==0)
                   a.num[p][temp]++;
            }
        }
    }
}
struct yun mul(struct yun t1,struct yun t2,int flag)//矩阵相乘
{
    int x,i,j,k;
    struct yun temp;
    if(flag==0)
        x=2;
    else
        x=(tot+1)*2;
    for(i=0;i<x;i++)
        for(j=0;j<x;j++)
        {
            temp.num[i][j]=0;
            for(k=0;k<x;k++)
                if(t1.num[i][k]>0&&t2.num[k][j]>0)
                    temp.num[i][j]+=t1.num[i][k]*t2.num[k][j];
        }
    return temp;
}
ll sum()//26^1+26^2+...+26^m
{
    struct yun x,y;
    int t=m;
    x.num[0][0]=y.num[0][0]=26;
    x.num[0][1]=y.num[0][1]=1;
    x.num[1][0]=y.num[1][0]=0;
    x.num[1][1]=y.num[1][1]=1;
    while(t)
    {
        if(t%2==1)
            y=mul(y,x,0);
        x=mul(x,x,0);
        t=t/2;
    }
    return y.num[0][1]-1;
}
ll solve()//A^1+A^2+...A^m
{
    int i,j,len=tot+1;
    ll ans=0;
    struct yun b,c;
    b.yunsi();
    c.yunsi();
    for(i=0;i<len;i++)
        for(j=0;j<len;j++)
            b.num[i][j]=c.num[i][j]=a.num[i][j];
    for(i=0;i<len;i++)
        for(j=len;j<len*2;j++)
            b.num[i][j]=c.num[i][j]=(i==j-len);
    for(i=len;i<len*2;i++)
        for(j=len;j<len*2;j++)
            b.num[i][j]=c.num[i][j]=(i==j);
    while(m)
    {
        if(m%2==1)
        {
            b=mul(b,c,1);
        }
        c=mul(c,c,1);
        m=m/2;
    }
    for(i=len;i<len*2;i++)
        ans+=b.num[0][i];
    return ans-1;
}
int main()
{
    char str[6];
    ll sum1,sum2;
    while(scanf("%d%d",&n,&m)!=EOF)
    {
       getchar();
       ca();//初始化
       a.yunsi();
       sum1=0;
       sum2=0;
       while(n--)
       {
           scanf("%s",str);
           insert(str);
       }
       AC_tree();
       sum1=sum();
       sum2=solve();
       sum1=sum1-sum2;
       if(sum1<0)//这里需要注意下
           sum1+=mod;
       printf("%I64u\n",sum1);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/jiangjing/p/3061382.html