Codeforces 76C Mutation

Description

描述

给出一个由前 $k$ 个大写字母组成的字符串,每次可以花费 $t_i$ 代价删去该字符串中所有的第 $i$ 个大写字母,一个字符串的代价为该字符串每对相邻字符的代价,给出一 $k imes k$ 的矩阵表示两个相邻的大写字母的代价矩阵,问有多少种删除方案使得删除的代价以及剩余字符串的代价之和不超过 $T$ ,注意剩余字符串需非空。

输入

第一行三个整数 $n,k,T$ 表示字符串长度,所用字母种类以及代价上限,之后输入一个由前 $k$ 个大写字母组成的长度为 $n$ 的字符串,以及 $n$ 个整数 $t_i$ 表示删去第 $i$ 种字符的代价,最后输入一个 $k imes k$ 的矩阵 $m_{x,y}$ 表示相邻两字母的代价矩阵 ($1≤n≤2⋅10^5$,$1≤k≤22$,$1≤T≤2⋅10^9$,$1≤t_i,a_{x,y}≤10^9$)。

输出

输出满足条件的方案数。

样例

输入

5 3 13
BACAC
4 1 2
1 2 3
2 3 4
3 4 10

输出

5

解释

下面是合法的方案:
$ ext{BACAC} (11), ext{ACAC} (10), ext{BAA} (5), ext{B}(6), ext{AA}(4)$。

Solution

考虑前缀和。

用一个 $k$ 位 $01$ 集合表示删字母情况,$0$ 表示不删,$1$ 表示删。

令 $f_i$ 表示删去集合为 $i$ 后的字符串代价,如果我们求出了每一个 $f_i$,就可以像这样统计答案:

int ans = 0;
for(int i = 0; i < 1 << k; i++)
    if((i & all) == i /* 删去字符在集合中 */ && f[i] <= T /* 符合代价要求 */ && i != all /* 剩余字符串不为空 */)
        ans++;
printf("%d
", ans);

其中 $all$ 表示所有出现字符全集,$T$ 为代价上限。

以下为读入及预处理:

scanf("%d%d%d%s", &n, &k, &T, s);                         
for(int i = 0; i < n; i++)                                
{                                                         
    s[i] -= 'A';                                          
    all |= 1 << s[i];                                     
}                                                         
for(int i = 0; i < k; i++) scanf("%d", &t[i]);            
for(int i = 0; i < k; i++)                                
    for (int j = 0; j < k; j++) scanf("%d", &a[i][j]);    
memset(can, -1, sizeof(can));                             
for(int i = 0; i < k; i++) f[1 << i] = t[i];

将每一种字符的损耗 $t_i$ 存进 $f_{2^i}$ 里,最后使用前缀和累加就可以将其加入每一种情况了。

累加 $f_i$ 就是从每一个比 $f_i$ 少一个元素的子集里累加(当然还有自己本来存储的值)。

例如 $f_{10110} = f_{10110} + f_{00110} + f_{10010} + f_{10100}$。

for(int i = 0; i < k; i++)           
    for(int j = 0; j < 1 << k; j++)  
        if((j >> i) & 1)             
            f[j] += f[j ^ (1 << i)]; 

注意这里累加的其实是 $f_j$。

前方高能

下面考虑如何求出每一个 $f$ 。

用 $can_i$ 表示从最近的字符$i$,到当前字符之间的字符集(两端不含),$can$ 初始化为 $-1$,表示该字符自身还没出现过。

我们遍历这个字符串,对于每一个字符 $s_i$,将它视为删去字符集的右端点(不含),再枚举所有的 $k$ 个字符 $j$,将离 $s_i$ 最近的 $j$ 视为删去字符集的左端点(不含),则删去的字符状态即为它们之间的所有种类的字符。

注意,并不是所有的 $j$ 都可以和 $s_i$ 匹配,一个合法的 $j$ 应有如下条件:

1. 在 $s_i$ 之前出现了字符 $j$,代码中体现为 can[j] >= 0 ;
2. $can_j$ 中不包含字符 $j$ 和 $s_i$,代码中体现为 !((can[j] >> j) & 1) && !((can[j] >> s[i]) & 1) 。

如果合法,那么删去两者之间的字符集后,该两个字符将会产生一个新的代价 $a_{s_i,j}$,将这个代价加入 $f_{can_j}$ 中,但需要注意的是 $f_{can_j|2^{s_i}}$ 和 $f_{can_j|2^j}$ 这两个不应该存在这个代价的状态将在最终被累加这个代价,所以将它们都减去 $a_{s_i,j}$,那此时 $f_{can_j|2^j|2^{s_i}}$ 其实被减去了两遍代价,那么将它加上 $a_{s_i,j}$。

也就是一个容斥啦。

f[can[j]] += a[j][s[i]];                         
f[can[j] | (1 << j)] -= a[j][s[i]];              
f[can[j] | (1 << s[i])] -= a[j][s[i]];           
f[can[j] | (1 << j) | (1 << s[i])] += a[j][s[i]];

还有两个细节,一是一个 $s_i$ 处理完后,就要将该 $can_{s_i}$ 清空,因为出现了一个新的 $s_i$,就是 can[s[i]] = 0 。

二是在遍历 $k$ 个字符时,只要该字符出现了,则在容斥(也有可能不容斥)后要标记 $j$ 可以和 $s_i$ 匹配,即 can[j] |= (1 << s[i]) 。

有一个问题:初始代价好像并没有预处理啊? 

这是因为当访问到 $s_i$ 时 $can_{s_{i-1}}$ 必定为 $0$,这时就会把 $a_{s_i,s_{i-1}}$ 加入 $f_0$ 中啦,其它情况也同理。

那么 AC 代码就出来了,时间复杂度 $mathcal O(nk+k imes 2^k)$。

#include <cstdio>
#include <cstring>
const int K = 25;
int n, k, T, all, can[K], t[K], a[K][K], f[1 << K];
char s[200005];
int main() 
{
    scanf("%d%d%d%s", &n, &k, &T, s);
    for(int i = 0; i < n; i++)
    {
        s[i] -= 'A';
        all |= 1 << s[i];
    }
    for(int i = 0; i < k; i++) scanf("%d", &t[i]);
    for(int i = 0; i < k; i++)
        for (int j = 0; j < k; j++) scanf("%d", &a[i][j]);
    memset(can, -1, sizeof(can));
    for(int i = 0; i < k; i++) f[1 << i] = t[i];               
    for(int i = 0; i < n; i++) 
    {
        for(int j = 0; j < k; j++)
            if(can[j] >= 0) 
            {
                if(!((can[j] >> j) & 1) && !((can[j] >> s[i]) & 1)) 
                {
                    f[can[j]] += a[j][s[i]];
                    f[can[j] | (1 << j)] -= a[j][s[i]];
                    f[can[j] | (1 << s[i])] -= a[j][s[i]];
                    f[can[j] | (1 << j) | (1 << s[i])] += a[j][s[i]];
                }
                can[j] |= (1 << s[i]);
            }
        can[s[i]] = 0;
    }
    for(int i = 0; i < k; i++)
        for(int j = 0; j < 1 << k; j++)
            if((j >> i) & 1)
                f[j] += f[j ^ (1 << i)]; 
    int ans = 0;
    for(int i = 0; i < 1 << k; i++)
        if((i & all) == i && f[i] <= T && i != all)
            ans++;
    printf("%d
", ans);
    return 0;
}

参考文献:
v5zsq.CodeForces 76 C.Mutation(状压DP+容斥原理+高维前缀和).CSDN

原文地址:https://www.cnblogs.com/syksykCCC/p/CF76C.html