【NOIP2016】天天爱跑步 题解(LCA+桶+树上差分)

题目链接

题目大意:给定一颗含有$n$个结点的树,每个结点有一个权值$w$。给定$m$条路径,如果一个点与路径的起点的距离恰好为$w$,那么$ans[i]++$。求所有结点的ans。

题目分析

暴力的做法当然是枚举条路径,然后玄学$dfs$,复杂度应该是$O(nm)$的。再根据约束条件可以拿到65pts。

正解

对于一条路径$(u,v)$,我们可以将其分成两段:$(u,lca(u,v))$和$(lca(u,v),v)$。

我们先来分析上行路段。上行路段的要求有3个:

1.$u$在以$i$为根的子树里面。

2.$lca(u,v)$在以$i$为根的子树外面。

3.$dep[u]=dep[i]+w[i]$

同理对于下行路段也有3个条件:

1.$v$在以$i$为根的子树里面。

2.$lca(u,v)$在以i为根的子树外面。

3.$dis[s,t]-dep[t]=w[i]-dep[i]$

这样我们可以枚举每个结点,即dfs整棵树,复杂度$O(n)$。

对于这道题,我们还需要用桶来统计贡献。具体操作方法:

b1:上行阶段的贡献值。

b2:下行阶段的贡献值。

void dfs2(int x)
{
    int t1=b1[w[x]+dep[x]], t2=b2[w[x]-dep[x]+maxn];//递归前的ans[x]
    for(int i=head[x]; i; i=edge[i].next)
    {
        int y=edge[i].to;
        if(y==fa[x][0]) continue;
        dfs2(y);//递归整棵树
    }
    b1[dep[x]]+=st[x];
    for(int i=head1[x]; i; i=edge1[i].next)//h1是用链式前向星存的每个点作为终点的路径集合
    {
        int y=edge1[i].to;
        b2[dis[y]-dep[t[y]]+maxn]++;//根据前面的等式。方法类似雨后的尾巴
    }
    ans[x]+=b1[w[x]+dep[x]]-t1+b2[w[x]-dep[x]+maxn]-t2;//加上差值
    ///////未完待续////////
}

我们不能忘记一点:树是递归进行操作的。

什么意思?还记得之前的约束条件吗?统计答案时$lca(u,v)$必然不能存在于子树中。所以当点i作为lca(u,v)时,统计完答案后要减去$(u,v)$对i的贡献。因为$(u,v)$的贡献对于i的祖先是不合法的。

for(int i=head2[x]; i; i=edge2[i].next)//h2是链式前向星存的每个点作为lca的路径集合
    {
        int y=edge2[i].to;
        b1[dep[s[y]]]--;
        b2[dis[y]-dep[t[y]]+maxn]--;
    }

主函数主要代码:

for (int i=1;i<=m;i++)
    {
        s[i]=read(),t[i]=read();
        int ll=lca(s[i],t[i]);
        dis[i]=dep[s[i]]+dep[t[i]]-2*dep[ll];
        st[s[i]]++;//统计以此点作为起点的路径条数
        add1(t[i],i);//
        add2(ll,i);
        if (dep[ll]+w[ll]==dep[s[i]]) ans[ll]--;//防止重复统计:当路径起点或终点恰好为两点LCA时且LCA处可以观察到运动员
    }

注意数组下标的平移。时间复杂度$O(nlogn)$。

完整代码:

#include<bits/stdc++.h>
#define int long long 
using namespace std;
const int maxn=400005;
int n,m;
int fa[maxn][21],dep[maxn],b1[maxn*2],b2[maxn*2];
int dis[maxn],ans[maxn],s[maxn],t[maxn],st[maxn],w[maxn];
int head[maxn*2],cnt,head1[maxn*2],cnt1,head2[maxn*2],cnt2;
struct node
{
    int next,to;
}edge[maxn*2],edge1[maxn*2],edge2[maxn*2];
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
    while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
void add(int x, int y)
{
    edge[++cnt].to=y;
    edge[cnt].next=head[x];
    head[x]=cnt;
}
void add1(int x, int y)
{
    edge1[++cnt1].to=y;
    edge1[cnt1].next=head1[x];
    head1[x]=cnt1;
}
void add2(int x, int y)
{
    edge2[++cnt2].to=y;
    edge2[cnt2].next=head2[x];
    head2[x]=cnt2;
}
inline void dfs1(int now)
{
    for (int i=1;(1<<i)<=dep[now];i++)
        fa[now][i]=fa[fa[now][i-1]][i-1];
    for (int i=head[now];i;i=edge[i].next)
    {
        int to=edge[i].to;
        if (to==fa[now][0]) continue;
        fa[to][0]=now;
        dep[to]=dep[now]+1;
        dfs1(to);
    }
}
inline int lca(int x,int y)
{
    if (x==y) return x;
    if (dep[x]<dep[y]) swap(x,y);
    int t=log(dep[x]-dep[y])/log(2);
    for (int i=t;i>=0;i--)
    {
        if (dep[fa[x][i]]>=dep[y])
            x=fa[x][i];
        if (x==y) return x;
    }
    t=log(dep[x])/log(2);
    for (int i=t;i>=0;i--)
    {
        if (fa[x][i]!=fa[y][i])
            x=fa[x][i],y=fa[y][i];
    }
    return fa[x][0];
}
void dfs2(int x)
{
    int t1=b1[w[x]+dep[x]], t2=b2[w[x]-dep[x]+maxn];
    for(int i=head[x]; i; i=edge[i].next)
    {
        int y=edge[i].to;
        if(y==fa[x][0]) continue;
        dfs2(y);
    }
    b1[dep[x]]+=st[x];
    for(int i=head1[x]; i; i=edge1[i].next)
    {
        int y=edge1[i].to;
        b2[dis[y]-dep[t[y]]+maxn]++;
    }
    ans[x]+=b1[w[x]+dep[x]]-t1+b2[w[x]-dep[x]+maxn]-t2;
    for(int i=head2[x]; i; i=edge2[i].next)
    {
        int y=edge2[i].to;
        b1[dep[s[y]]]--;
        b2[dis[y]-dep[t[y]]+maxn]--;
    }
}

signed main()
{
    n=read(),m=read();
    for (int i=1;i<n;i++)
    {
        int x=read(),y=read();
        add(x,y);add(y,x);
    }
    dep[1]=1;fa[1][0]=1;
    dfs1(1);
    for (int i=1;i<=n;i++) w[i]=read();
    for (int i=1;i<=m;i++)
    {
        s[i]=read(),t[i]=read();
        int ll=lca(s[i],t[i]);
        dis[i]=dep[s[i]]+dep[t[i]]-2*dep[ll];
        st[s[i]]++;
        add1(t[i],i);
        add2(ll,i);
        if (dep[ll]+w[ll]==dep[s[i]]) ans[ll]--;
    }
    dfs2(1);
    for (int i=1;i<=n;i++) printf("%lld ",ans[i]);
    return 0;
}
原文地址:https://www.cnblogs.com/Invictus-Ocean/p/13187270.html