Description
给 (n) 个基本词汇, (m) 个禁忌词语。求用基本词汇(每个词汇可重复词汇)拼成长度为 (L) 的
不包含任何禁忌词语的字符串的方案数。
Solution
在 数据规模与约定 中,我们发现可以把数据划分成两档:
-
(L le 100) 的(前 (60pts)) 。
-
基本长度不超过 (2) 的
第一档 60pts
显然不包含这个东西判定可以用 AC 自动机,用 (m) 个禁忌词语建 AC 自动机,把非法点标记一下(即所有词语末尾及其在 fail 树上的子树)。
然后方案数这个东西显然的朴素 DP:
- 设 (f[i][j]) 为前 (i) 个字符,当前在 AC 自动机上的编号是 (j)
考虑转移就是拼接一个基本词汇,假设目前在 (f[i][u]),设词汇长度为 (l),就让这个词汇从 (u) 出发在 AC 自动机上跑全串,然后落在 (v) 节点。要求这一段路径上都不能走非法点才可以转移。这样 (f[i][u]) 就对 (f[i + l][v]) 有了加的贡献。
每次转移这个过程可以先用 (O(100 ^ 2N))枚举 (u) 和基本词汇来确定是否为合法转移和转移到的图上节点。
然后 DP 复杂度是 (O(100NL)) 的,可以跑过。
下面 $60pts $ 的代码
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int N = 55, S = 105, P = 1e9 + 7;
int n, m, L, q[S], g[S][N], f[S][S], len[N];
int idx = 0, tr[S][26], fail[S];
bool e[S];
char a[N][S], b[N][S];
void insert(char s[]) {
int p = 0;
for (int i = 0; s[i]; i++) {
int ch = s[i] - 'a';
if (!tr[p][ch]) tr[p][ch] = ++idx;
p = tr[p][ch];
}
e[p] = true;
}
void build() {
int hh = 0, tt = -1;
for (int i = 0; i < 26; i++)
if (tr[0][i]) q[++tt] = tr[0][i];
while (hh <= tt) {
int u = q[hh++];
for (int i = 0; i < 26; i++) {
int v = tr[u][i];
if (v) {
fail[v] = tr[fail[u]][i];
if (e[fail[v]]) e[v] = true;
q[++tt] = v;
} else tr[u][i] = tr[fail[u]][i];
}
}
}
void work(int u, int x) {
int p = u;
if (e[p]) { g[u][x] = -1; return; }
for (int i = 0; i < len[x]; i++) {
p = tr[p][a[x][i] - 'a'];
if (e[p]) { g[u][x] = -1; return; }
}
g[u][x] = p;
}
void solve1() {
for (int u = 0; u <= idx; u++)
for (int i = 1; i <= n; i++) work(u, i);
f[0][0] = 1;
for (int i = 0; i < L; i++) {
for (int u = 0; u <= idx; u++) {
if (e[u] || !f[i][u]) continue;
for (int j = 1; j <= n; j++) {
if (g[u][j] != -1) {
(f[i + len[j]][g[u][j]] += f[i][u]) %= P;
}
}
}
}
int ans = 0;
for (int i = 0; i <= idx; i++)
if (!e[i]) (ans += f[L][i]) %= P;
printf("%d
", ans);
}
int main() {
scanf("%d%d%d", &n, &m, &L);
for (int i = 1; i <= n; i++) scanf("%s", a[i]), len[i] = strlen(a[i]);
for (int i = 1; i <= m; i++) {
scanf("%s", b[i]);
insert(b[i]);
}
build();
solve1();
return 0;
}
第二档
看到 (L le 10^8),我们的复杂度显然不能跟 (L) 相关。
先考虑长度都是 (1) 的时候,那么转移过程中,(f[i]) 这层只会对 (f[i + 1]) 这层产生贡献,而且每次产生贡献的方式是相同的、且是加和贡献,这个东西我们是会用矩阵乘法优化的,即用每一层作为一个矩阵,每一次迭代计算下一层结果,类似 HNOI2008 GT考试。
接着用相同的思想考虑长度 $ le 2$ 的情况,那么 (f[i]) 这层只会对 (f[i + 1], f[i + 2]) 产生影响,这种情况仍可以用矩阵乘法优化,即每一个矩阵记录两层数据,但是要更复杂一点。
设计矩阵为 ([F_{i,1}, F_{i, 2},...,F_{i, L}, F_{i + 1, 1}, F_{i + 1, 2}, ...,F_{i + 1, L} ])
考虑把 ([F_i, F_{i + 1}] imes A = [F_{i + 1}, F_{i + 2}]),构造一个矩阵 (A)。
- 考虑对 (F_{i + 1}) 的贡献:首先右侧的 (F_{i + 1}) 平移到左边了,所以矩阵左下方应该是一个单位矩阵
- (F_{i}) 对 (F_{i + 2}) 的贡献: 就是那些长度为 (2) 的基本词汇转移, (u Rightarrow v) 产生 (1) 的贡献
- (F_{i + 1}) 对 (F_{i + 2}) 的贡献同理。
一些细节问题:
- 注意下标之间的调整
- 一开始矩阵是 ([F_0, F_1]),所以要先计算 (F[1]) 这层 (就是找那些长度为 (1) 的转移)。
这个时间复杂度是 (O(100^3logL)) 的。
总代码 100pts
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
typedef long long LL;
const int N = 55, S = 105, P = 1e9 + 7;
int n, m, L, q[S], g[S][N], f[S][S], len[N];
int idx = 0, tr[S][26], fail[S];
bool e[S];
char a[N][S], b[N][S];
// 矩阵
struct Matrix{
int n, m, w[S << 1][S << 1];
Matrix operator * (const Matrix &b) const {
Matrix c; memset(c.w, 0, sizeof c.w);
c.n = n, c.m = b.m;
for (int i = 1; i <= n; i++)
for (int j = 1; j <= b.m; j++)
for (int k = 1; k <= m; k++)
c.w[i][j] = (c.w[i][j] + (LL)w[i][k] * b.w[k][j]) % P;
return c;
}
} A, res;
// AC 自动机
void insert(char s[]) {
int p = 0;
for (int i = 0; s[i]; i++) {
int ch = s[i] - 'a';
if (!tr[p][ch]) tr[p][ch] = ++idx;
p = tr[p][ch];
}
e[p] = true;
}
void build() {
int hh = 0, tt = -1;
for (int i = 0; i < 26; i++)
if (tr[0][i]) q[++tt] = tr[0][i];
while (hh <= tt) {
int u = q[hh++];
for (int i = 0; i < 26; i++) {
int v = tr[u][i];
if (v) {
fail[v] = tr[fail[u]][i];
if (e[fail[v]]) e[v] = true;
q[++tt] = v;
} else tr[u][i] = tr[fail[u]][i];
}
}
}
// 预处理转移方式
void work(int u, int x) {
int p = u;
if (e[p]) { g[u][x] = -1; return; }
for (int i = 0; i < len[x]; i++) {
p = tr[p][a[x][i] - 'a'];
if (e[p]) { g[u][x] = -1; return; }
}
g[u][x] = p;
}
// 前 60 pts
void solve1() {
f[0][0] = 1;
for (int i = 0; i < L; i++) {
for (int u = 0; u <= idx; u++) {
if (e[u] || !f[i][u]) continue;
for (int j = 1; j <= n; j++) {
if (g[u][j] != -1) {
(f[i + len[j]][g[u][j]] += f[i][u]) %= P;
}
}
}
}
int ans = 0;
for (int i = 0; i <= idx; i++)
if (!e[i]) (ans += f[L][i]) %= P;
printf("%d
", ans);
}
// 获取 id
int num(int id, int c) {
return id + 1 + c * (idx + 1);
}
// 后 40 pts
void solve2() {
res.n = 1, res.m = A.n = A.m = (idx + 1) * 2;
res.w[1][num(0, 0)] = 1;
for (int i = 1; i <= n; i++) {
if (g[0][i] != -1) {
int v = g[0][i];
if (len[i] == 1) res.w[1][num(v, 1)]++; // 计算 F[1] 这层
}
}
// 构造 A
for (int u = 0; u <= idx; u++) {
A.w[num(u, 1)][num(u, 0)] ++;
for (int i = 1; i <= n; i++) {
if (g[u][i] != -1) {
int v = g[u][i];
if (len[i] == 1) A.w[num(u, 1)][num(v, 1)]++;
else A.w[num(u, 0)][num(v, 1)]++;
}
}
}
int b = L;
while (b) {
if (b & 1) res = res * A;
A = A * A;
b >>= 1;
}
int ans = 0;
for (int i = 0; i <= idx; i++)
if (!e[i]) (ans += res.w[1][i + 1]) %= P;
printf("%d
", ans);
}
int main() {
scanf("%d%d%d", &n, &m, &L);
for (int i = 1; i <= n; i++) scanf("%s", a[i]), len[i] = strlen(a[i]);
for (int i = 1; i <= m; i++) {
scanf("%s", b[i]);
insert(b[i]);
}
build();
for (int u = 0; u <= idx; u++)
for (int i = 1; i <= n; i++) work(u, i);
if (L <= 100) solve1();
else solve2();
return 0;
}