点分治初步

  点分治是一种常用于处理树上点对关系的分治算法。

一、算法介绍

  提到点分治,我们先来看一道例题:洛谷P3806 【模板】点分治1

  题意:多组询问,边有边权,询问树上是否存在距离为$ k $的点对。$ n leq 10^4, k leq 10^7 $

  我们显然有一种暴力算法:对于每个询问,枚举每个点对判断距离是否等于给定的$ k $,复杂度$ O(mn^2) $,但这样的复杂显然太高了,我们需要更快的算法。

  我们发现如果钦定了树根,那么$ dist(x,y)=dep[x]+dep[y]-2 imes dep[lca] $,于是我们可以尝试枚举$ lca $,然后搜索子树中的每个节点,对于遍历到的当前结点$ now $,寻找是否存在一个已遍历节点$ x $使$ dep[x]=k+2 imes dep[lca]-dep[y] $,这个我们可以用一个桶存下已遍历结点到$ lca $的距离。然而我们发现状况还是没什么变化,复杂度依然是$ O(mn^2) $的。

  然而,我们发现在计算$ lca $的贡献的时候,我们相当于一次统计了经过$ lca $的所有路径,于是我们接下来只需对没有经过$ lca$的路径,也就是$ lca $的每个子树分别统计,这就是所谓的“分治”,将问题分解成几个子问题分别处理。

  但是,这和前面的算法有什么区别?这样的算法在某些数据下还是不够优秀,比如当树是一条链的情况下,每次$ O(n) $处理完当前节点后,然后往子节点走,这样时间复杂度还是会被卡到$ O(n^2) $。虽然如此,我们发现造成时间复杂度退化的原因是,我们在处理每个子树时钦定的根节点不够优秀。这使我们就会想到无根树上有一个性质极其优秀的点:重心。重心满足以它为根,它的每个子树的大小都不会超过$ frac{n}{2} $的性质,如果我们每次开始处理一块子树,都以这块子树的重心为根开始处理,每次能使问题的规模下降以半,最多分治$ O(log{n}) $层。在上面的例题中,总复杂度为$ O(mnlog{n}) $

  现在我们总结一下点分治的基本思路:

    1、先计算出当前统计的这一块树的答案。(注意:计算答案时需保证点对是在$ lca $不同的子树中,这样才能保证路径经过$ lca $,如例题中,我们处理完一个子树后才把该子树的信息加入桶中)

    2、找出这块树的重心。

    3、把重心删除,递归处理该树断开成的几棵子树。

  核心代码:

void divide(int now)
{
    solve(now);//计算以now为根的这棵树的答案
    vis[now]=1;//删除节点now
    for(int i=son of now){
        int nxt=getroot(i);//计算儿子i所在的这棵子树的重心
        divide(nxt);//递归分治处理
    }
}

  以下是例题的完整代码:

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<ctime>
#include<algorithm>
#define ll long long
#define maxn 10010
inline ll read()
{
    ll x=0; char c=getchar(),f=1;
    for(;c<'0'||'9'<c;c=getchar())if(c=='-')f=-1;
    for(;'0'<=c&&c<='9';c=getchar())x=x*10+c-'0';
    return x*f;
}
inline void write(ll x)
{
    char buf[20],len; len=0;
    if(x<0)putchar('-'),x=-x;
    for(;x;x/=10)buf[len++]=x%10+'0';
    if(!len)putchar('0');
    else while(len)putchar(buf[--len]);
}
inline void writesp(ll x){write(x); putchar(' ');}
inline void writeln(ll x){write(x); putchar('
');}
struct edge{
    int to,nxt,d;
}e[2*maxn];
int fir[maxn],dist[maxn],size[maxn],vis[maxn];
int id[maxn];
int mark[10000010];
int q[110],ok[110];
int n,m,tot;
void add_edge(int x,int y,int z){e[tot].to=y; e[tot].d=z; e[tot].nxt=fir[x]; fir[x]=tot++;}
void search(int now,int fa)
{
    id[++tot]=now; size[now]=1;
    for(int i=fir[now];~i;i=e[i].nxt)
        if(e[i].to!=fa&&!vis[e[i].to]){
            dist[e[i].to]=dist[now]+e[i].d;
            search(e[i].to,now);
            size[now]+=size[e[i].to];
        }
}
void solve(int now)
{
    tot=1; id[1]=now; dist[now]=0; mark[0]=1;
    int last=1;
    for(int i=fir[now];~i;i=e[i].nxt)
        if(!vis[e[i].to]){
            dist[e[i].to]=e[i].d;
            search(e[i].to,now);
            for(int j=1;j<=m;j++){
                if(ok[j])continue;
                for(int k=last+1;k<=tot;k++)
                    if(dist[id[k]]<=q[j])ok[j]|=mark[q[j]-dist[id[k]]];
            }
            for(int j=last+1;j<=tot;j++)
                if(dist[id[j]]<=10000000)mark[dist[id[j]]]=1;
            last=tot;
        }
    for(int i=1;i<=tot;i++)
        if(dist[id[i]]<=10000000)mark[dist[id[i]]]=0;
}
int getroot(int now,int fa,int S)
{
    int mx=0;
    size[now]=1;
    for(int i=fir[now];~i;i=e[i].nxt)
        if(e[i].to!=fa&&!vis[e[i].to]){
            int t=getroot(e[i].to,now,S);
            if(t)return t;
            size[now]+=size[e[i].to];
            if(size[e[i].to]>mx)mx=size[e[i].to];
        }
    if(S-size[now]>mx)mx=S-size[now];
    if(mx*2<=S)return now;
    else return 0;
}
void divide(int now)
{
    solve(now);
    vis[now]=1;
    for(int i=fir[now];~i;i=e[i].nxt)
        if(!vis[e[i].to]){
            int rt=getroot(e[i].to,now,size[e[i].to]);
            divide(rt);
        }
}
int main()
{
    n=read(); m=read();
    memset(fir,255,sizeof(fir)); tot=0;
    for(int i=1;i<n;i++){
        int x=read(),y=read(),z=read();
        add_edge(x,y,z); add_edge(y,x,z);
    }
    for(int i=1;i<=m;i++)
        q[i]=read();
    int rt=getroot(1,-1,n);
    divide(rt);
    for(int i=1;i<=m;i++)
        puts(ok[i]?"AYE":"NAY");
    return 0;
}
luoguP3806

二、练习

  1、洛谷P4149 [IOI2011]Race

    题意:给一棵有边权的树求一条长度最短的距离为$ K $的路径。

    显然是一道点分治裸题,在统计每个$ lca $的答案时用一个桶记录到$ lca $距离一定的结点的最小深度,然后其他做法与上题基本相同。

    代码:

// luogu-judger-enable-o2
#include<cstdio>
#include<cstring>
#include<cmath>
#include<cstdlib>
#include<ctime>
#include<algorithm>
#define ll long long
#define inf 0x3f3f3f3f
#define maxn 200010
inline ll read()
{
    ll x=0; char c=getchar(),f=1;
    for(;c<'0'||'9'<c;c=getchar())if(c=='-')f=-1;
    for(;'0'<=c&&c<='9';c=getchar())x=x*10+c-'0';
    return x*f;
}
inline void write(ll x)
{
    static char buf[20],len; len=0;
    if(x<0)x=-x,putchar('-');
    for(;x;x/=10)buf[len++]=x%10+'0';
    if(!len)putchar('0');
    else while(len)putchar(buf[--len]);
}
inline void writesp(ll x){write(x); putchar(' ');}
inline void writeln(ll x){write(x); putchar('
');}
struct edge{
    int to,nxt,d;
}e[2*maxn];
int fir[maxn],size[maxn],dep[maxn],vis[maxn];
ll dist[maxn];
int id[maxn],mn[1000010];
int n,m,tot,ans;
void add_edge(int x,int y,int z){e[tot].to=y; e[tot].d=z; e[tot].nxt=fir[x]; fir[x]=tot++;}
void search(int now,int fa)
{
    id[++tot]=now; size[now]=1;
    for(int i=fir[now];~i;i=e[i].nxt)
        if(e[i].to!=fa&&!vis[e[i].to]){
            dist[e[i].to]=dist[now]+e[i].d; dep[e[i].to]=dep[now]+1;
            search(e[i].to,now);
            size[now]+=size[e[i].to];
        }
}
void solve(int now)
{
    mn[0]=0; tot=0;
    int last=1;
    for(int i=fir[now];~i;i=e[i].nxt)
        if(!vis[e[i].to]){
            dist[e[i].to]=e[i].d; dep[e[i].to]=1;
            search(e[i].to,now);
            for(int j=last;j<=tot;j++)
                if(dist[id[j]]<=m)ans=std::min(ans,dep[id[j]]+mn[m-dist[id[j]]]);
            for(int j=last;j<=tot;j++)
                if(dist[id[j]]<=m)mn[dist[id[j]]]=std::min(mn[dist[id[j]]],dep[id[j]]);
            last=tot+1;
        }
    for(int i=1;i<=tot;i++)
        if(dist[id[i]]<=m)mn[dist[id[i]]]=inf;
    mn[0]=inf;
}
int getroot(int now,int fa,int S)
{
//    printf("%d %d %d ****
",now,fa,S);
    size[now]=1;
    int mx=0;
    for(int i=fir[now];~i;i=e[i].nxt)
        if(e[i].to!=fa&&!vis[e[i].to]){
            int tmp=getroot(e[i].to,now,S);
            if(~tmp)return tmp;
            size[now]+=size[e[i].to];
            if(size[e[i].to]>mx)mx=size[e[i].to];
        }
    if(S-size[now]>mx)mx=S-size[now];
    if(mx<<1<=S)return now;
    else return -1;
}
void divide(int now)
{
//    writeln(now);
//    system("pause");
    solve(now);
    vis[now]=1;
    for(int i=fir[now];~i;i=e[i].nxt)
        if(!vis[e[i].to]){
            int nxt=getroot(e[i].to,now,size[e[i].to]);
            divide(nxt);
        }
}
int main()
{
    n=read(); m=read();
    memset(fir,255,sizeof(fir)); tot=0;
    for(int i=1;i<n;i++){
        int x=read(),y=read(),z=read();
        add_edge(x,y,z); add_edge(y,x,z);
    }
    memset(mn,0x3f,sizeof(mn)); ans=inf;
    int init=getroot(0,-1,n);
    divide(init);
    writeln(ans!=inf?ans:-1);
    return 0;
}
luoguP4149
原文地址:https://www.cnblogs.com/quzhizhou/p/10651030.html