【BZOJ4012】开店(HNOI2015)-动态点分治+set

测试地址:开店
题目大意:给定一棵树,每个点有点权,每条边有长度,多次询问,每次询问给出三个数u,l,r,问点权在[l,r]内的所有点到点u的距离之和。
做法:本题需要用到动态点分治+set。
注意到每次只询问一端为定点的所有路径,这就提示我们使用动态点分治。我们每次查询时,从询问点开始在点分树上往上走,每走到一个点,就统计另一个端点在这个点的某棵子树中的所有路径的贡献,包含询问点的子树是不计算的(因为会算重),因为一个点度数最多为3,所以一个点的子树数量是常数级别的。
那么怎么计算这个贡献呢?我们需要求在点分树的某棵子树中,某个点权范围内的点数(要计算有多少条路径包含从询问点到子树根的祖先的一段),以及这些点到子树根的祖先的距离和,我们显然可以用一个set来存储这些信息的前缀和,这样我们就可以O(logn)求出这些东西来了。于是我们就得到了一个时间复杂度为O((n+m)log2n)的动态点分治做法,可以通过此题(虽然常数巨大)。
(据说还有树链剖分+主席树的做法……改天再去看看吧)
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN=150010;
int n,Q,x[MAXN],first[MAXN]={0},firstt[MAXN]={0},tot=0;
int fa[MAXN],dep[MAXN],siz[MAXN],mxson[MAXN],totsiz,q[MAXN];
ll A,dis[MAXN][20];
bool vis[MAXN]={0};

struct edge
{
    int v,next;
    ll w;
}e[MAXN<<1],t[MAXN<<1];
void insert(int a,int b,ll c) {e[++tot].v=b,e[tot].next=first[a],e[tot].w=c,first[a]=tot;}
void insertt(int a,int b) {t[++tot].v=b,t[tot].next=firstt[a],firstt[a]=tot;}

struct point
{
    int id;
    ll sum,siz;
    bool operator < (point a) const
    {
        return id<a.id;
    }
}nxt;

set<point>::iterator it;
struct Set
{
    set<point> a;
    ll sum,siz;
    void push(int id,ll sum,ll siz)
    {
        nxt.id=id,nxt.sum=sum,nxt.siz=siz;
        it=a.lower_bound(nxt);
        if (it!=a.end()&&(*it).id==nxt.id)
        {
            nxt.sum+=(*it).sum;
            nxt.siz+=(*it).siz;
            a.erase(it);
            a.insert(nxt);
        }
        else a.insert(nxt);
    }
    void calc(int l,int r,ll &anssum,ll &anssiz)
    {
        anssum=anssiz=0;
        nxt.id=l;
        it=a.lower_bound(nxt);
        if (it!=a.end())
        {
            anssum-=(*it).sum;
            anssiz-=(*it).siz;
        }
        else anssum-=sum,anssiz-=siz;
        nxt.id=r;
        it=a.upper_bound(nxt);
        if (it!=a.end())
        {
            anssum+=(*it).sum;
            anssiz+=(*it).siz;
        }
        else anssum+=sum,anssiz+=siz;
    }
}s[MAXN];

void init()
{
    scanf("%d%d%lld",&n,&Q,&A);
    for(int i=1;i<=n;i++)
        scanf("%d",&x[i]);
    for(int i=1;i<n;i++)
    {
        int a,b;
        ll c;
        scanf("%d%d%lld",&a,&b,&c);
        insert(a,b,c),insert(b,a,c);
    }
    int ans=0;
}

void dp(int v,int Fa)
{
    siz[v]=1;
    mxson[v]=0;
    q[++totsiz]=v;
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=Fa&&!vis[e[i].v])
        {
            dp(e[i].v,v);
            mxson[v]=max(mxson[v],siz[e[i].v]);
            siz[v]+=siz[e[i].v];
        }
}

int find(int v)
{
    totsiz=0;
    dp(v,0);
    int ans=100000000,ansi;
    for(int i=1;i<=totsiz;i++)
        if (max(mxson[q[i]],totsiz-siz[q[i]])<ans)
        {
            ans=max(mxson[q[i]],totsiz-siz[q[i]]);
            ansi=q[i];
        }
    return ansi;
}

void put(int v,int tar,ll d,int Fa)
{
    dis[v][dep[v]-dep[tar]]=d;
    s[tar].push(x[v],d,1);
    for(int i=first[v];i;i=e[i].next)
        if (e[i].v!=Fa&&!vis[e[i].v]) put(e[i].v,tar,d+e[i].w,v);
}

int solve(int v,int Fa)
{
    v=find(v);
    vis[v]=1;
    fa[v]=Fa;
    dep[v]=dep[Fa]+1;
    for(int i=first[v];i;i=e[i].next)
        if (!vis[e[i].v])
        {
            int nx=solve(e[i].v,v);
            insertt(v,nx);
            put(e[i].v,nx,e[i].w,0);
        }

    for(int i=firstt[v];i;i=t[i].next)
    {
        it=s[t[i].v].a.begin();
        ll last=0,siz=0;
        while(it!=s[t[i].v].a.end())
        {
            s[t[i].v].sum+=(*it).sum;
            s[t[i].v].siz+=(*it).siz;
            nxt.id=(*it).id,nxt.sum=last,nxt.siz=siz;
            s[t[i].v].a.erase(it);
            s[t[i].v].a.insert(nxt);
            last=s[t[i].v].sum;
            siz=s[t[i].v].siz;
            it=s[t[i].v].a.upper_bound(nxt);
        }
    }

    vis[v]=0;
    return v;
}

ll calc(int v,ll L,ll R)
{
    int now=v,last=0,h=0;
    ll ans=0;
    while(now)
    {
        ll totsum,totsiz;
        for(int i=firstt[now];i;i=t[i].next)
            if (t[i].v!=last)
            {
                s[t[i].v].calc(L,R,totsum,totsiz);
                if (h>0) ans+=dis[v][h-1]*totsiz+totsum;
                else ans+=totsum;
            }
        if (x[now]>=L&&x[now]<=R&&h>0) ans+=dis[v][h-1];
        h++;
        last=now;
        now=fa[now];
    }
    return ans;
}

void work()
{
    tot=0;
    dep[0]=0;
    solve(1,0);
    ll lastans=0;
    for(int i=1;i<=Q;i++)
    {
        int u;
        ll a,b,L,R;
        scanf("%d%lld%lld",&u,&a,&b);
        L=min((a+lastans)%A,(b+lastans)%A);
        R=max((a+lastans)%A,(b+lastans)%A);
        printf("%lld
",lastans=calc(u,L,R));
    }
}

int main()
{
    init();
    work();

    return 0;
}
原文地址:https://www.cnblogs.com/Maxwei-wzj/p/9793370.html