CF1073G Yet Another LCP Problem

#include<bits/stdc++.h>
typedef long long ll;
const int maxn = 400100;
char buf[maxn];
int n,q;
namespace sam{
	int c[maxn][26],mx[maxn],fa[maxn],r[maxn],la,tot;
	int n,q;
	inline void ins(int w,int id){
		int p=la,now=la=++tot; mx[now] = mx[p] + 1; r[id] = tot;
		for(;p && !c[p][w];p = fa[p]) c[p][w] = now;
		if(!p) fa[now] = 1;
		else {
			int q = c[p][w];
			if(mx[q] == mx[p] + 1) fa[now] = q;
			else {
				int x = ++ tot; fa[x] = fa[q]; mx[x] = mx[p] + 1;
				memcpy(c[x],c[q],sizeof c[x]); fa[q] = fa[now] = x;
				for(;p && c[p][w] == q;p = fa[p]) c[p][w] = x;
			}
		}
	}
}
struct T{ int to,nxt; }way[maxn];
int h[maxn],num;
inline void adde(int x,int y){ way[++num]={y,h[x]},h[x]=num; }
int size[maxn],dfn[maxn],fa[maxn],dep[maxn],st[20][maxn],idx;
inline void dfs(int x){
	size[x] = 1,st[0][dfn[x] = ++ idx] = x,dep[x] = dep[fa[x]] + 1;
	for(int i=h[x];i;i=way[i].nxt)
		fa[way[i].to]=x,dfs(way[i].to), size[x] += size[way[i].to];
	st[0][idx] = fa[x];
}
inline int min(int x,int y){ return dep[x] < dep[y] ? x : y; }
inline int lca(int x,int y){
	if(dfn[x] > dfn[y]) std::swap(x,y);
	const int lg = 31 - __builtin_clz(dfn[y] - dfn[x]);
	return x == y ? x : min(st[lg][dfn[x]],st[lg][dfn[y] - (1<<lg)]);
}
ll ans = 0;
namespace vt{
	struct T{ int to,nxt; }way[maxn << 1];
	int h[maxn],num;
	int c1[maxn],c2[maxn];
	inline void adde(int x,int y){ way[++num]={y,h[x]},h[x]=num; }
	inline void dfs(int x,int f = 0){
		for(int & i = h[x];i;i=way[i].nxt)
			dfs(way[i].to,x),c1[x] += c1[way[i].to],c2[x] += c2[way[i].to];
		ans += ll(sam::mx[x] - sam::mx[f]) * c1[x] * c2[x];
	}
}
inline int cmp(int x,int y){ return dfn[x] < dfn[y]; }
int main(){
	std::ios::sync_with_stdio(false),std::cin.tie(0);
	std::cin >> n >> q >> buf + 1; sam::la=sam::tot=1;
	for(int i=n;i>=1;--i) sam::ins(buf[i] - 'a',i);
	for(int i=2;i<=sam::tot;++i){
		adde(sam::fa[i],i);
	}
	dfs(1);
	for(int i=1;i<20;++i) for(int j=1;j+(1<<i)-1<sam::tot;++j)
		st[i][j]=min(st[i-1][j],st[i-1][j+(1<<i-1)]);
	for(int i=1;i<=q;++i){
		int n1,n2; std::vector<int> v; std::cin >> n1 >> n2;
		for(int i=1,x;i<=n1;++i) std::cin >> x,x = sam::r[x],vt::c1[x] = 1,v.push_back(x);
		for(int i=1,x;i<=n2;++i) std::cin >> x,x = sam::r[x],vt::c2[x] = 1,v.push_back(x);
		std::sort(v.begin(),v.end(),cmp),v.erase(std::unique(v.begin(),v.end()),v.end());
		for(int i=0,s=v.size();i+1<s;++i) v.push_back(lca(v[i],v[i+1]));
		std::sort(v.begin(),v.end(),cmp),v.erase(std::unique(v.begin(),v.end()),v.end());
		static int st[maxn]; int top = 1; st[top] = v[0]; vt::num = 0;
		for(int i=1;i<v.size();++i){
			while(top > 1 && dfn[v[i]] >= dfn[st[top]] + size[st[top]]) -- top;
			vt::adde(st[top],v[i]); st[++top] = v[i];
		}
		ans = 0,vt::dfs(v[0]),std::cout << ans << '
';
		for(int i=0;i<v.size();++i) vt::c1[v[i]] = vt::c2[v[i]] = 0;
	}
}

原文地址:https://www.cnblogs.com/skip1978/p/11409150.html