【树】点分治学习笔记

不做笔记的后果是我完全忘记了我在5个月前就学过点分治(去洛谷做题才发现的).....

点分治大概是用于树上路径的求解。

点分治分4步走:

1,对当前子树找重心//固定

void getroot(int u,int fa)
{
    siz[u]=1;ms[u]=0;
    for(int i=0;i<p[u].size();i++)
    {
        int v=p[u][i].to;
        if(v==fa||vis[v])continue;
        getroot(v,u);
        siz[u]+=siz[v];
        ms[u]=max(ms[u],siz[v]);
    }
    ms[u]=max(ms[u],Tsiz-siz[u]);
    if(ms[u]<ms[root])root=u;
}

2,把树上距离(边权)存进临时数组//根据实际需要稍作修改 。。。比如临时数组需要的不是树上距离

void getdis(int u,int fa,int d)
{
    dis[++cnt]=d;
    for(int i=0;i<p[u].size();i++)
    {
        int v=p[u][i].v,w=p[u][i].w;
        if(vis[v]||v==fa) continue;
        getdis(v,u,d+w);
    }
} 

3,用于divide函数分层解决的solve函数//不同题都不大相同

  ans是局部变量。

int solve(int u,int d)
{
    int ans=0;
    cnt=0;
    getdis(u,0,d);//此时临时数组dis的[1,cnt]存储数据 
    /*
    之后统计 C(n,2)个 两个dis的和就是树上(真or/假)通过根的两点的距离
    不能真的C(n,2), n^2会超时哭的。
    需要自己各种优化 
    */
    return ans; 
}

4,divide函数,找重心,递归,分层求解

  Tsiz最初等于n。ans是全局变量。

void divide(int u)
{
    vis[u]=1;
    ans+=solve(u,0);
    for(int i=0;i<p[u].size();i++)
    {
        int v=p[u][i].to,w=p[u][i].w;
        if(vis[v])continue;
        ans-=solve(v,w);
        Tsiz=siz[v];
        root=0;getroot(v,u);
        divide(root);
    }
    return ;
}

模板例题

洛谷P4178 Tree 找树上距离小于或等于固定值K的所有点对数

#include<bits/stdc++.h>
#define debug printf("!");
using namespace std;
typedef long long ll;
const int maxn=1e5+50;
const int inf=0x3f3f3f3f;

int K,ans;

struct P{
    int v,w;
};
vector<P>p[maxn];

int cnt,dis[maxn],siz[maxn],ms[maxn],vis[maxn],Tsiz,root;

void getroot(int u,int fa)
{
    siz[u]=1;ms[u]=0;
    for(int i=0;i<p[u].size();i++)
    {
        int v=p[u][i].v;
        if(v==fa||vis[v])continue;
        getroot(v,u);
        siz[u]+=siz[v];
        ms[u]=max(ms[u],siz[v]);
    }
    ms[u]=max(ms[u],Tsiz-ms[u]);
    if(ms[root]>ms[u])root=u;
}
void getdis(int u,int fa,int d)
{
    dis[++cnt]=d;
    for(int i=0;i<p[u].size();i++)
    {
        int v=p[u][i].v,w=p[u][i].w;
        if(vis[v]||v==fa) continue;
        getdis(v,u,d+w);
    }
}

int solve(int u,int d)
{
    int ans=0;
    cnt=0;
    getdis(u,0,d);
    sort(dis+1,dis+1+cnt);
    int l=1,r=cnt;
    for(l=1;l<r;l++)
    {
        while(dis[l]+dis[r]>K&&r>l)r--;
        if(r==l)break;
        ans+=r-l;
    }
    return ans; 
}

void divide(int u)
{
    vis[u]=1;
    ans+=solve(u,0);
    for(int i=0;i<p[u].size();i++)
    {
        int v=p[u][i].v,w=p[u][i].w;
        if(vis[v])continue;
        ans-=solve(v,w);
        Tsiz=siz[v];
        root=0;getroot(v,u);
        divide(root);
    }
}
int main()
{
    int n,i,u,v,w;
    scanf("%d",&n);
    for(i=1;i<n;i++)
    {
        scanf("%d%d%d",&u,&v,&w);
        p[u].push_back(P{v,w});
        p[v].push_back(P{u,w});
    }
    scanf("%d",&K);
    root=0;ms[0]=Tsiz=n;getroot(1,0);divide(root);
    printf("%d
",ans);
}
View Code

hncpc 2019找树上距离为2019的倍数的所有点对数

#include<bits/stdc++.h>
#define debug printf("!");
using namespace std;
typedef long long ll;
const int inf=0x3f3f3f3f;
const int maxn=1e6+50;
 
struct P{
    int to,w;
};
vector<P>p[maxn];
 
int ans;
 
int siz[maxn],ms[maxn],vis[maxn],root,cnt,Tsiz;
 
void getroot(int u,int fa)
{
    siz[u]=1;ms[u]=0;
    for(int i=0;i<p[u].size();i++)
    {
        int v=p[u][i].to;
        if(v==fa||vis[v])continue;
        getroot(v,u);
        siz[u]+=siz[v];
        ms[u]=max(ms[u],siz[v]);
    }
    ms[u]=max(ms[u],Tsiz-siz[u]);
    if(ms[u]<ms[root])root=u;
}
 
int y[maxn],ny[maxn];
void getdis(int u,int fa)
{
    ny[y[u]]++;
    for(int i=0;i<p[u].size();i++)
    {
        int v=p[u][i].to,w=p[u][i].w;
        if(vis[v]||v==fa)continue;
        y[v]=(y[u]+w)%2019;
        getdis(v,u);
    }
}
int solve(int u,int w)
{
    for(int i=0;i<=2019;i++)ny[i]=0;
    y[u]=w%2019;
    getdis(u,0);
    int res=0;
    res=ny[0]*(ny[0]-1)/2;
    for(int i=1;i<=1009;i++)res+=ny[i]*ny[2019-i];
    return res;
}
void divide(int u)
{
    vis[u]=1;
    ans+=solve(u,0);
    for(int i=0;i<p[u].size();i++)
    {
        int v=p[u][i].to,w=p[u][i].w;
        if(vis[v])continue;
        ans-=solve(v,w);
        Tsiz=siz[v];
        root=0;getroot(v,u);
        divide(root);
    }
    return ;
}
 
int main()
{
    int n,i,j,k,u,v,w;
    while(~scanf("%d",&n))
    {
        ans=0;
        for(i=1;i<=n;i++)
        {
            p[i].clear();siz[i]=0;ms[i]=0;vis[i]=0;
        }
        for(i=1;i<n;i++)
        {
            scanf("%d%d%d",&u,&v,&w);
            p[u].push_back(P{v,w});
            p[v].push_back(P{u,w});
        }
        root=0;ms[0]=Tsiz=n;getroot(1,0);
        divide(root);
        printf("%d
",ans);
    }
}
View Code
原文地址:https://www.cnblogs.com/kkkek/p/11616859.html