点分治

啥是点分治?

点分治一般来说用于解决大规模的树上路径问题。比如说最经典的一道题,给定一棵树,计算一共有多少点对满足之间距离<=k。

这种题一般数据范围在10^4~10^5,直接暴力求是n^2的肯定会超时。

那怎么办?我们考虑分治。

先说一下点分治的基本思想,就是对于每一棵树,先找到这棵树的重心。

啥叫重心?重心就是一棵树中最大的子树节点最少的那个点。求树的重心就是直接暴力DFS,但是毕竟人家是O(n)的。

先看一眼代码。

void getroot(int x,int fa)
{
    size[x] = 1,maxs[x] = 0;
    for(int i = head[x];i;i = e[i].next)
    {
        int t = e[i].to;
        if(t == fa || vis[t]) continue;
        getroot(t,x);
        size[x] += size[t];
        maxs[x] = max(maxs[x],size[t]);
    }
    maxs[x] = max(maxs[x],sum - size[x]);
    if(maxs[x] < maxs[root]) root = x;
}

其中size记录节点子树大小,maxs记录最大子树节点大小。

最后一步就是因为还要计算自己的父亲和父亲之上的一些树,所以要进行更新。

这样我们就成功的O(n)求出了重心。

好接着上面的说,我们求出重心之后,首先统计这棵子树之内的所有答案,之后再分别递归到重心分割开的所有子树里面去统计答案。

怎么统计呢?首先我们对于每棵子树,我们从根开始向下进行dfs,更新到达每个点需要的距离,并且把出现过的点的距离全部加入当前统计数组里面。之后把统计数组排序,从两头开始找,只要当前两点之间的距离小于等于k,那么就直接加上r-l个答案,直到l>r为止。

不过这里有一些问题要注意,就是如果直接这么统计会出问题,因为我们只计算那些经过重心的道路,而这种计算方法它会在本次计算中重复计算一些子树中的情况,这样再向下递归的时候答案就会算重。所以对于每次计算,我们要再减去所有子树中的可能情况。之后向下递归求子树重心,继续求解。

这就是大致的操作了。然后注意的是每次我们需要不断更改当前树大小(这个好做直接用size赋值)

还有就是我们要每次在getroot之前把根结点的值赋成0,毕竟各个过程是相对独立的,可以避免很多不必要的麻烦。

总复杂度大概是O(nlog^2n),更多的细节看一下代码。

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<iostream>
#include<queue>
#include<set>
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
#define enter putchar('
')

using namespace std;
typedef long long ll;
const int M = 100005;

int read()
{
    int ans = 0,op = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9')
    {
        if(ch == '-') op = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        ans *= 10;
        ans += ch - '0';
        ch = getchar(); 
    }
    return ans * op;
}

struct edge
{
    int next,to,v;
}e[M];

int n,k,head[M],dis[M],ecnt,size[M],maxs[M],root,x,y,z,sum,tot,cur[M],ans;
bool vis[M];

void add(int x,int y,int z)
{
    e[++ecnt].to = y;
    e[ecnt].v = z;
    e[ecnt].next = head[x];
    head[x] = ecnt;
}

void getroot(int x,int fa)//求树的重心
{
    size[x] = 1,maxs[x] = 0;
    for(int i = head[x];i;i = e[i].next)
    {
        int t = e[i].to;
        if(t == fa || vis[t]) continue;
        getroot(t,x);
        size[x] += size[t];
        maxs[x] = max(maxs[x],size[t]);
    }
    maxs[x] = max(maxs[x],sum - size[x]);
    if(maxs[x] < maxs[root]) root = x;
}

void getdis(int x,int fa,int leng)//leng记录当前的长度
{
    cur[++tot] = leng;
    for(int i = head[x];i;i = e[i].next)
    {
        int t = e[i].to;
        if(t == fa || vis[t]) continue;
        getdis(t,x,leng + e[i].v);
    }
}

int calc(int x,int leng)
{
    tot = 0;
    getdis(x,0,leng);
    sort(cur+1,cur+1+tot);
    int l = 1,r = tot,temp = 0;
    while(l < r)//排序之后从两头开始找。因为对于每两个点如果其符合,那么对于l节点,从l+1~r全部是合法的点对。
    {
        if(cur[l] + cur[r] <= k) temp += r - l,l++;
        else r--;
    }
    rep(i,1,tot) cur[i] = 0;//注意这里不能使用memset,否则会超时
    return temp;
}

void solve(int x)
{
    vis[x] = 1;ans += calc(x,0);//计算当前子的答案
    for(int i = head[x];i;i = e[i].next)
    {
        int t = e[i].to;
        if(vis[t]) continue;
        ans -= calc(t,e[i].v);//减去每棵子树的答案,注意这里必须把初始的长度传进去,否则你相当于统计的路径长度少了一截
        sum = size[t],maxs[root = 0] = n;
        getroot(t,0),solve(root);//继续找重心之后递归求解
    }
}

int main()
{
    n = read();
    rep(i,1,n-1) x = read(),y = read(),z = read(),add(x,y,z),add(y,x,z);
    k = read();
    sum = maxs[root] = n,getroot(1,0);//找到树的重心
    solve(root);//开始递归求解
    printf("%d
",ans);
    return 0; 
} 
原文地址:https://www.cnblogs.com/captain1/p/9644181.html