【LOJ6074】【2017 山东一轮集训 Day6】子序列 DP

题目描述

  有一个由前 (m) 个小写字母组成的串 (S),有 (q) 个询问,每次给你 (l,r),问你 (S_{lldots r}) 有多少个非空子序列。

  (m=9,n=lvert S vert leq {10}^5,qleq {10}^5)

题解

  题解接下来的部分求得答案是包含空串的答案。

  最简单的做法是DP。

  设 (f_{i,j}) 为前 (i) 个字符,末尾为 (j) 的子序列个数。

  特殊的,如果 (j=m+1) 就说明当前还没有选任何字符。

  递推式为

[f_{i,j}= egin{cases} f_{i-1,j}&,j eq S_i\ sum_{k=1}^{m+1}f_{i-1,k}&,j=S_i end{cases} ]

  答案为 (sum_{i=1}^{m+1}f_{n,i})

  复杂度为 (O(nq))

  注意到这个转移是一个矩阵乘法的形式,记

[F_i= egin{pmatrix} f_{i,1}\f_{i,2}\vdots\f_{i,m+1} end{pmatrix} ]

  那么

[egin{align} F_i&=A_iF_{i-1}\ A_i&= egin{pmatrix} 1&&&&\ &1&&&\ 1&1&1&1&1\ &&&1&\ &&&&1 end{pmatrix} end{align} ]

  这里 (A_i) 就是把单位矩阵的第 (S_i) 行全部设为一得到的矩阵。

  记

[egin{align} U&= egin{pmatrix} 1\1\vdots\1 end{pmatrix}\ V&= egin{pmatrix} 0\0\vdots\1 end{pmatrix} end{align} ]

  那么我们最终的答案就是

[egin{align} &VA_rA_{r-1}cdots A_lU\ =&VA_rA_{r-1}cdots A_1{A_1}^{-1}{A_2}^{-1}ldots A_{l-1}U end{align} ]

  我们可以维护一个 (A) 的前缀积和 (A) 的逆的前缀积即可。需要注意矩阵乘法的顺序。

  时间复杂度:(O(nm^3+qm^2))

  可以发现,做矩阵乘法的时候只有一行有变化,那么预处理的矩阵乘法的复杂度可以降到 (O(nm^2))。总复杂度为 (O((n+q)m^2))

  其实这道题还能进一步优化。

  对于询问,我们只需要在预处理的时候把矩阵乘以 (U)(V) 的结果保存下来,就可以做到 (O(qm))

  对于预处理,求 (A_rA_{r-1}cdots A_1) 的时候左乘转移矩阵的时候实际上是把 (S_i) 这一行中每个位置的值改为这一列所有元素的和,直接维护一下每列的和就好了。右乘转移矩阵的逆矩阵就是对于每一行,除了 (S_i) 这一列外其他所有列都减掉这个位置。那么可以维护一下这一列所有元素共同减掉的数,然后修改一下这个位置单点的值。

  具体来说,假设原来的矩阵的某一行 (0) 是长这样:

[egin{pmatrix} a_1-v&a_2-v&a_3-v&a_4-v\ end{pmatrix} ]

  对第三个位置操作后就会变成

[egin{pmatrix} a_1-a_3&a_2-a_3&(2a_3-v)-a_3&a_4-a_3 end{pmatrix} ]

  这样预处理的复杂度就降到了 (O(nm))

  总复杂度为 (O((n+q)m))

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<functional>
#include<cmath>
#include<vector>
//using namespace std;
using std::min;
using std::max;
using std::swap;
using std::sort;
using std::reverse;
using std::random_shuffle;
using std::lower_bound;
using std::upper_bound;
using std::unique;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef std::pair<int,int> pii;
typedef std::pair<ll,ll> pll;
void open(const char *s){
#ifndef ONLINE_JUDGE
	char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
int rd(){int s=0,c,b=0;while(((c=getchar())<'0'||c>'9')&&c!='-');if(c=='-'){c=getchar();b=1;}do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9');return b?-s:s;}
void put(int x){if(!x){putchar('0');return;}static int c[20];int t=0;while(x){c[++t]=x%10;x/=10;}while(t)putchar(c[t--]+'0');}
int upmin(int &a,int b){if(b<a){a=b;return 1;}return 0;}
int upmax(int &a,int b){if(b>a){a=b;return 1;}return 0;}
const int p=1000000007;
const int N=100010;
const int M=10;
ll fp(ll a,ll b)
{
	ll s=1;
	for(;b;b>>=1,a=a*a%p)
		if(b&1)
			s=s*a%p;
	return s;
}
const ll inv2=fp(2,p-2);
int plus(int a,int b)
{
	a+=b;
	return a>=p?a-p:a;
}
int plus2(int a)
{
	a+=a;
	return a>=p?a-p:a;
}
int minus(int a,int b)
{
	a-=b;
	return a<0?a+p:a;
}
int c[N];
int n,q;
int f1[N][10];
int a1[10][10];
int f2[N][10];
int a2[10][10];
char str[N];
void init()
{
	for(int i=0;i<=9;i++)
		a1[i][i]=a2[i][i]=f1[0][i]=1;
	for(int i=1;i<=n;i++)
	{
		int v=str[i]-'a';
		for(int j=0;j<=9;j++)
		{
			f1[i][j]=minus(plus2(f1[i-1][j]),a1[v][j]);
			a1[v][j]=f1[i-1][j];
			f2[i][j]=a2[v][j];
			a2[v][j]=minus(plus2(a2[v][j]),f2[i-1][j]);
		}
	}
}
int main()
{
	scanf("%s",str+1);
	n=strlen(str+1);
	scanf("%d",&q);
	init();
	int l,r;
	for(int i=1;i<=q;i++)
	{
		l=rd();
		r=rd();
		int ans=f1[r][9]-1;
		for(int j=0;j<=8;j++)
			ans=(ans-(ll)f1[r][j]*f2[l-1][j])%p;
		ans=plus(ans,p);
		printf("%d
",ans);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/ywwyww/p/9245582.html