[ZJOI2019] 语言

一、题目

点此看题

二、解法

( t ZJOI) 真有意思!

首先思考一下部分分吧,对于 (20) 分的 ( t subtask) 满足原树是一条链,我们可以维护每个点为左端点最远覆盖到的右端点,用线段树求区间最大值然后把所有位置的贡献求个和即可,时间复杂度 (O(nlog n))

能不能把这个做法搬到树上呢?那么我们思考单个点的贡献,只考虑经过了这个点的路径,感性一点就是把所有极远点连接起来,那么让所有路径端点联通的最小生成树大小就是答案。

肯定不能真的去求这个生成树,能否巧妙地计算出生成树大小?考虑已有的生成树点集 (S),我们加入一个点 (u),那么找出最深的 (lca(u,v),vin S)(u) 的贡献的 (dep[u]-dep[lca]),你发现这个东西很像虚树,你把所有点按 ( t dfn) 排序之后 (lca=lca(u,last)),其中 (last) 表示上一个加入的点。

注意第一个加入点的贡献并不是它的深度,而应该减去最上面联通点的深度,所以可以先记贡献为 (dep[x]),最后再减去 (dep[lca]),这里的 (lca) 是指的所有点的 (lca)

那么还剩下一个问题,就是如何取出经过了这个点的路径,可以考虑树上差分(反正只需要计算一次答案),在 (u,v) 处加入路径 ((u,v)),在 (lca(u,v),fa[lca(u,v)]) 处删除路径 ((u,v))

加入点的问题可以用线段树维护,线段树的下标需要按 ( t dfn) 排序,上传的时候减去左儿子最右边的点和右儿子最左边的点的 (lca) 的深度即可,然后线段树是支持标记合并的,所以写个线段树合并就行了,时间复杂度 (O(nlog n))

#include <cstdio>
#include <vector>
#include <iostream>
using namespace std;
const int M = 200005;
const int N = 100*M;
#define ll long long
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,m,k,tot,Ind,f[M],dfn[M],id[M],num[M],dep[M],dp[M][20];
int cnt,lg[M],p[M],rt[M],sum[N],tim[N],li[N],ri[N],ls[N],rs[N];
vector<int> g[M];ll ans;
struct edge
{
    int v,next;
    edge(int V=0,int N=0) : v(V) , next(N) {}
}e[2*M];
void dfs(int u,int fa)
{
    dfn[u]=++Ind;id[Ind]=u;
    dep[u]=dep[fa]+1;p[u]=fa;
    num[u]=++k;dp[k][0]=u;
    for(int i=f[u];i;i=e[i].next)
    {
        int v=e[i].v;
        if(v==fa) continue;
        dfs(v,u);
        dp[++k][0]=u;
    }
}
int Min(int x,int y)
{
    return dep[x]<dep[y]?x:y;
}
int lca(int l,int r)
{
    l=num[l];r=num[r];
    if(l>r) swap(l,r);
    int k=lg[r-l+1];
    return Min(dp[l][k],dp[r-(1<<k)+1][k]);
}
void up(int x)
{
    sum[x]=sum[ls[x]]+sum[rs[x]]-dep[lca(ri[ls[x]],li[rs[x]])];
    li[x]=li[ls[x]]?li[ls[x]]:li[rs[x]];
    ri[x]=ri[rs[x]]?ri[rs[x]]:ri[ls[x]];
}
int merge(int x,int y,int l,int r)
{
    if(!x || !y) return x|y;
    if(l==r)
    {
        tim[x]+=tim[y];
        li[x]=ri[x]=tim[x]?id[l]:0;
        sum[x]=tim[x]?dep[id[l]]:0;
        return x;
    }
    int mid=(l+r)>>1;
    ls[x]=merge(ls[x],ls[y],l,mid);
    rs[x]=merge(rs[x],rs[y],mid+1,r);
    up(x);
    return x;
}
void ins(int &x,int l,int r,int t,int f)
{
    if(!x) x=++cnt;
    if(l==r)
    {
        tim[x]+=f;
        li[x]=ri[x]=tim[x]?id[t]:0;
        sum[x]=tim[x]?dep[id[t]]:0;
        return ;
    }
    int mid=(l+r)>>1;
    if(t<=mid) ins(ls[x],l,mid,t,f);
    else ins(rs[x],mid+1,r,t,f);
    up(x);
}
void work(int u)
{
    for(int i=f[u];i;i=e[i].next)
    {
        int v=e[i].v;
        if(v==p[u]) continue;
        work(v);
        rt[u]=merge(rt[u],rt[v],1,n);
    }
    for(int i=0;i<g[u].size();i++)
        ins(rt[u],1,n,dfn[g[u][i]],-1);
    ans+=sum[rt[u]]-dep[lca(li[rt[u]],ri[rt[u]])];
}
signed main()
{
    n=read();m=read();
    for(int i=1;i<n;i++)
    {
        int u=read(),v=read();
        e[++tot]=edge(v,f[u]),f[u]=tot;
        e[++tot]=edge(u,f[v]),f[v]=tot;
    }
    dfs(1,0);
    for(int i=2;i<=k;i++)
        lg[i]=lg[i>>1]+1;
    for(int j=1;(1<<j)<=k;j++)
        for(int i=1;i+(1<<j)-1<=k;i++)
            dp[i][j]=Min(dp[i][j-1],dp[i+(1<<j-1)][j-1]);
    for(int i=1;i<=m;i++)
    {
        int x=read(),y=read(),t=lca(x,y);
        ins(rt[x],1,n,dfn[x],1);ins(rt[x],1,n,dfn[y],1);
        ins(rt[y],1,n,dfn[x],1);ins(rt[y],1,n,dfn[y],1);
        g[t].push_back(x);g[p[t]].push_back(x);
        g[t].push_back(y);g[p[t]].push_back(y);
    }
    work(1);
    printf("%lld
",ans/2);
}
原文地址:https://www.cnblogs.com/C202044zxy/p/14725359.html