洛谷P1600 天天爱跑步(线段树合并)

小c同学认为跑步非常有趣,于是决定制作一款叫做《天天爱跑步》的游戏。《天天爱跑步》是一个养成类游戏,需要玩家每天按时上线,完成打卡任务。

这个游戏的地图可以看作一一棵包含 nn个结点和 n-1n−1条边的树, 每条边连接两个结点,且任意两个结点存在一条路径互相可达。树上结点编号为从11到nn的连续正整数。

现在有mm个玩家,第ii个玩家的起点为 S_iS
i
​ ,终点为 T_iT
i
​ 。每天打卡任务开始时,所有玩家在第00秒同时从自己的起点出发, 以每秒跑一条边的速度, 不间断地沿着最短路径向着自己的终点跑去, 跑到终点后该玩家就算完成了打卡任务。 (由于地图是一棵树, 所以每个人的路径是唯一的)

小c想知道游戏的活跃度, 所以在每个结点上都放置了一个观察员。 在结点jj的观察员会选择在第W_jW
j
​ 秒观察玩家, 一个玩家能被这个观察员观察到当且仅当该玩家在第W_jW
j
​ 秒也理到达了结点 jj 。 小C想知道每个观察员会观察到多少人?

注意: 我们认为一个玩家到达自己的终点后该玩家就会结束游戏, 他不能等待一 段时间后再被观察员观察到。 即对于把结点jj作为终点的玩家: 若他在第W_jW
j
​ 秒前到达终点,则在结点jj的观察员不能观察到该玩家;若他正好在第W_jW
j
​ 秒到达终点,则在结点jj的观察员可以观察到这个玩家。

输入输出格式
输入格式:
第一行有两个整数nn和mm 。其中nn代表树的结点数量, 同时也是观察员的数量, mm代表玩家的数量。

接下来 n- 1n−1行每行两个整数uu和 vv,表示结点 uu到结点 vv有一条边。

接下来一行 nn个整数,其中第jj个整数为W_jW
j
​ , 表示结点jj出现观察员的时间。

接下来 mm行,每行两个整数S_iS
i
​ ,和T_iT
i
​ ,表示一个玩家的起点和终点。

对于所有的数据,保证1leq S_i,T_ileq n, 0leq W_jleq n1≤S
i
​ ,T
i
​ ≤n,0≤W
j
​ ≤n 。

输出格式:
输出1行 nn个整数,第jj个整数表示结点jj的观察员可以观察到多少人。

输入输出样例
输入样例#1:
6 3
2 3
1 2
1 4
4 5
4 6
0 2 5 1 2 3
1 5
1 3
2 6
输出样例#1:
2 0 0 1 1 1
输入样例#2:
5 3
1 2
2 3
2 4
1 5
0 1 0 3 0
3 1
1 4
5 5
输出样例#2:
1 2 1 0 1

题解:号称提高组最难一题,其实难度还行

考虑把一个人的路径拆成起点到lca和lca到终点两段,差分一下用线段树维护

具体的操作是对起点插入deep起点,终点插入2*deep[lca]-deep起点,相当于把起点沿lca翻上去。

然后线段树合并一波就搞定了

查询的是每个点deep+wj和deep-wj距离的点有几个

其实线段树合并是大材小用了,如果对每个点查询li-ri时间之间有多少人经过显然才更妙

代码如下:

#include<bits/stdc++.h>
#define lson tr[now].l
#define rson tr[now].r
using namespace std;

struct tree
{
    int l,r,sum;
} tr[20000010];

vector<int> g[300010];
vector<int> op1[300010],op2[300010];
int n,m,ans[300010],q[300010],rt[300010],deep[300010],fa[300010][20],cnt;
int N=600000;

int dfs(int now,int f,int dep)
{
    deep[now]=dep;
    fa[now][0]=f;
    rt[now]=now;
    ++cnt;
    for(int i=1; i<=19; i++)
    {
        fa[now][i]=fa[fa[now][i-1]][i-1];
    }
    for(int i=0; i<g[now].size(); i++)
    {
        if(g[now][i]==f) continue;
        dfs(g[now][i],now,dep+1);
    }
}

int lca(int x,int y)
{
    if(deep[x]<deep[y]) swap(x,y);
    for(int i=19; i>=0; i--)
    {
        if(deep[fa[x][i]]>=deep[y]) x=fa[x][i];
    }
    if(x==y) return x;
    for(int i=19; i>=0; i--)
    {
        if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    }
    return fa[x][0];
}

int push_up(int now)
{
    tr[now].sum=tr[lson].sum+tr[rson].sum;
}

int insert(int &now,int l,int r,int pos,int val)
{
    if(!now) now=++cnt;
    if(l==r)
    {
        tr[now].sum+=val;
        return 0;
    }
    int mid=(l+r)>>1;
    if(pos<=mid)
    {
        insert(lson,l,mid,pos,val);
    }
    else
    {
        insert(rson,mid+1,r,pos,val);
    }
    push_up(now);
}

int query(int now,int l,int r,int pos)
{
    if(l==r) return tr[now].sum;
    int mid=(l+r)>>1;
    if(pos<=mid) return query(lson,l,mid,pos);
    else return query(rson,mid+1,r,pos);
}

int merge(int a,int b,int l,int r)
{
    if(!b) return a;
    if(!a) return b;
    if(l==r)
    {
        tr[a].sum+=tr[b].sum;
        return a;
    }
    int mid=(l+r)>>1;
    tr[a].l=merge(tr[a].l,tr[b].l,l,mid);
    tr[a].r=merge(tr[a].r,tr[b].r,mid+1,r);
    push_up(a);
    return a;
}

int solve(int now,int f)
{
    for(int i=0; i<op1[now].size(); i++)
    {
        insert(rt[now],0,N,op1[now][i],1);
    }
    for(int i=0; i<op2[now].size(); i++)
    {
        insert(rt[now],0,N,op2[now][i],-1);
    }
    for(int i=0; i<g[now].size(); i++)
    {
        if(g[now][i]==f) continue;
        solve(g[now][i],now);
        merge(rt[now],rt[g[now][i]],0,N);
    }
    if(deep[now]+n-q[now]>=0) ans[now]+=query(rt[now],0,N,deep[now]+n-q[now]);
    if(deep[now]+n+q[now]<=N&&q[now]!=0) ans[now]+=query(rt[now],0,N,deep[now]+n+q[now]);
}

int main()
{
    int from,to;
    scanf("%d%d",&n,&m);
    for(int i=1; i<n; i++)
    {
        scanf("%d%d",&from,&to);
        g[from].push_back(to);
        g[to].push_back(from);
    }
    for(int i=1; i<=n; i++) scanf("%d",&q[i]);
    dfs(1,0,1);
    for(int i=1; i<=m; i++)
    {
        scanf("%d%d",&from,&to);
        int anc=lca(from,to);
        op1[from].push_back(deep[from]+n);
        op1[to].push_back(n-(deep[from]-deep[anc])+deep[anc]);
        op2[anc].push_back(deep[from]+n);
        op2[fa[anc][0]].push_back(n-(deep[from]-deep[anc])+deep[anc]);
    }
    solve(1,0);
    for(int i=1;i<=n;i++)
    {
        printf("%d ",ans[i]);
    }
}
原文地址:https://www.cnblogs.com/stxy-ferryman/p/9800654.html