noip模拟赛 单词

分析:这道题真心难想.最主要的是怎么样不重复.

      为了不重复统计,把所有符合条件的单词分成两类,一类是某些单词的前缀,一类是 不是任何单词的前缀.涉及到前缀后缀,维护两个trie树,处理3个数组a,b,c. a[i][j]表示长度为i-1的前缀,第i位接字母j是不是任何单词的前缀的个数. b[i][j]表示长度为i,最后一个字母为j,并且不是词典中单词的前缀的个数.c[i][j]表示长度为i,第一个字母为j的后缀的个数.

      先统计每个单词本身.再来考虑每个单词除了自身外的前缀.比如一个单词abcd,它的前缀有abc,ab,a.现在的任务就是看能不能拼出它们.比较棘手的一个问题就是每一个单词可以从多个位置划分,abc可以划分成ab c,也可以划分成a bc,为了不重复统计同一个单词,强行规定划分最后面的一个字符.

因为这一类单词都是前缀+后缀拼接起来的,所以划分出来的最后一个字符一定要作为某个单词的后缀,整个单词必须是词典中某个单词的前缀,这是由分类决定的,利用b,c两个数组能统计出答案,由于b数组的定义,保证了不会将一个词典中出现过的单词统计两次.

      还有一类是 不是任意单词的前缀的单词.利用a,c两个数组来统计.两个单词abce,cd. 

a[4][d]*c[1][d]就是这一类单词有多少个.需要意会一下,两类的答案相加就是答案了.

      为了使统计不重复,可以把所有的元素划分为若干个没有交集的集合,分别统计.

#include <cstdio>
#include <cmath>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

const int mod = 1e9 + 7;

int n, q, a[60][60], b[60][60], c[60][60];
long long cnt[120];
char s[60];

struct node
{
    int tree[500010][30];
    int flag[500010];
    int tot;
    node() {
        memset(tree, 0, sizeof(tree));
        memset(flag, 0, sizeof(flag));
        tot = 0;
    }
    void insert(char *s)
    {
        int id = 0;
        for (int i = 0; s[i]; i++)
        {
            if (tree[id][s[i] - 'a'])
                id = tree[id][s[i] - 'a'];
            else
                id = tree[id][s[i] - 'a'] = ++tot;
        }
        flag[id]++;
    }
    void dfs(int x, int l)
    {
        for (int i = 0; i < 26; i++)
        {
            if (tree[x][i] == 0 && x)
                a[l][i]++;
            if (tree[x][i] && x && !flag[tree[x][i]])
                b[l + 1][i]++;
            if (tree[x][i])
                dfs(tree[x][i], l + 1);
        }
    }
    void dfs2(int x, int l)
    {
        for (int i = 0; i < 26; i++)
            if (tree[x][i])
            {
                c[l + 1][i]++;
                dfs2(tree[x][i], l + 1);
            }
    }
}t1,t2;

int main()
{
    scanf("%d%d", &n, &q);
    for (int i = 1; i <= n; i++)
    {
        scanf("%s", s);
        int len = strlen(s);
        t1.insert(s);
        reverse(s, s + len);
        t2.insert(s);
        cnt[len]++;
    }
    t1.dfs(0,0);
    t2.dfs2(0,0);
    for (int i = 1; i <= 50; i++)
        for (int j = 0; j < 26; j++)
            if (c[1][j])
                cnt[i] += b[i][j];
    for (int i = 1; i <= 50; i++)
        for (int j = 1; j <= 50; j++)
            for (int k = 0; k < 26; k++)
                cnt[i + j] += 1LL * a[i][k] * c[j][k];
    int l;
    while (q--)
    {
        scanf("%d", &l);
        printf("%lld
", cnt[l] % mod);
    }

    return 0;
}
原文地址:https://www.cnblogs.com/zbtrs/p/7778824.html