树的点分治 模板题（Luogu P3806）

给定一棵有 n 个点、边带权的树，以及 m 次询问。

每次询问给出一个 k，问树上是否存在距离为 k 的点对。

AC code:

#include<bits/stdc++.h>
using namespace std;
// Array capacities, each a little larger than the largest index it must hold.
constexpr int MAXN = 10005;     // vertices
constexpr int MAXM = 105;       // queries
constexpr int MAXK = 10000005;  // distance values tracked in the lookup tables

int n;          // number of vertices
int m;          // number of queries
int q[MAXM];    // q[1..m]: the queried distances

// Forward-star adjacency list: edge ids start at 1 and 0 ends a chain.
// Both directions of every tree edge are stored, hence MAXN << 1 slots.
int fir[MAXN];        // fir[u]: id of the most recently added edge leaving u
int to[MAXN << 1];    // to[e]: head vertex of edge e
int nxt[MAXN << 1];   // nxt[e]: next edge in the same vertex's chain
int wt[MAXN << 1];    // wt[e]: weight of edge e
int cnt;              // number of edge slots used so far
/**
 * Read one (optionally negative) decimal integer from stdin into num.
 *
 * Every non-digit character before the number is skipped; each '-' met while
 * skipping flips the sign. If end-of-file is reached before any digit appears,
 * num is set to 0 and the function returns instead of looping forever.
 */
inline void read(int &num)
{
    // int, not char: getchar() returns EOF (a negative int), and the <cctype>
    // functions are only defined for EOF and values representable as unsigned char.
    int ch;
    int flag = 1;
    while (!isdigit(ch = getchar()))
    {
        if (ch == EOF) { num = 0; return; }
        if (ch == '-') flag = -flag;
    }
    for (num = ch - '0'; isdigit(ch = getchar()); num = num * 10 + ch - '0');
    num *= flag;
}
// Append the directed edge u -> v with weight w to the front of u's edge chain.
inline void Add(int u, int v, int w)
{
    ++cnt;
    wt[cnt] = w;
    to[cnt] = v;
    nxt[cnt] = fir[u];
    fir[u] = cnt;
}

// State shared by the centroid search (getroot) and the path counting (dfs/solve).
int total;       // size assumed for the component currently being searched
int root;        // best centroid candidate found so far (0 = none; mx[0] is a sentinel)
int mx[MAXN];    // mx[u]: largest piece that would remain if u were removed
int dis[MAXN];   // dis[u]: distance from the current centroid to u
int sz[MAXN];    // sz[u]: subtree size of u in the most recent getroot pass

bool vis[MAXN];   // vis[u]: u has already served as a centroid
bool Ans[MAXK];   // Ans[k]: a pair at distance q[k] was found; NOTE(review): only 1..m is used, MAXM would do
bool Exist[MAXK]; // Exist[d]: distance d from the current centroid is recorded (see solve)

// Raise x to y when y is larger. Returns whether x changed.
inline bool chkmax(int &x, int y)
{
    if (!(y > x)) return false;
    x = y;
    return true;
}

// Compute sz[] for the part of the current component hanging below u (entered from ff,
// never crossing a retired centroid). Along the way, move `root` to the vertex whose
// largest remaining piece mx[] is smallest.
void getroot(int u, int ff)
{
    sz[u] = 1;
    mx[u] = 0;
    for (int e = fir[u]; e != 0; e = nxt[e])
    {
        const int v = to[e];
        if (v == ff || vis[v]) continue;
        getroot(v, u);
        sz[u] += sz[v];
        mx[u] = std::max(mx[u], sz[v]);
    }
    // The piece "above" u: everything in the component outside u's subtree.
    mx[u] = std::max(mx[u], total - sz[u]);
    if (mx[u] < mx[root]) root = u;
}
int stk[MAXN];  // distances collected by dfs, held in stk[1..indx]
int indx;       // number of entries currently in stk
// Push dis[u], then the distances of every vertex below u (entered from ff, never
// crossing a retired centroid) onto stk. Each child's dis[] is its parent's plus the edge weight.
inline void dfs(int u, int ff)
{
    indx++;
    stk[indx] = dis[u];
    for (int e = fir[u]; e; e = nxt[e])
    {
        const int v = to[e];
        if (v != ff && !vis[v])
        {
            dis[v] = dis[u] + wt[e];
            dfs(v, u);
        }
    }
}

int bin[MAXN];  // every distance marked in Exist during the current solve, for cheap cleanup
int Cnt;        // number of entries in bin
// For centroid u, mark every query q[k] that is the length of some path through u.
// Subtrees of u are handled one at a time. Each distance of the current subtree is
// matched against Exist[], which holds 0 (u itself) plus the distances of the subtrees
// already processed, so the two ends of a matched path never lie in the same subtree.
inline void solve(int u)
{
    // Exist[] has a fixed size, but path lengths (sums of edge weights) are not bounded
    // by it, so every index into Exist[] is range-checked to avoid out-of-bounds access.
    const int limit = static_cast<int>(sizeof(Exist) / sizeof(Exist[0]));
    Exist[0] = 1;
    for (int i = fir[u]; i; i = nxt[i])
    {
        if (vis[to[i]]) continue;
        dis[to[i]] = wt[i];
        dfs(to[i], u);  // fills stk[1..indx] with this subtree's distances
        for (int j = 1; j <= indx; j++)
            for (int k = 1; k <= m; k++)
            {
                const int rest = q[k] - stk[j];  // length still needed on the other side of u
                if (rest >= 0 && rest < limit) Ans[k] |= Exist[rest];
            }
        // Publish this subtree's distances only now, so they pair with later subtrees only.
        while (indx)
        {
            const int d = stk[indx--];
            if (d >= 0 && d < limit) Exist[d] = 1, bin[++Cnt] = d;
        }
    }
    // Clear just the entries that were set; a memset over all of Exist[] here would TLE.
    while (Cnt) Exist[bin[Cnt--]] = 0;
}

// Centroid-decomposition driver for the component whose centroid is u.
// Precondition (set up by main for the first call, and by this function for its own
// recursive calls): `total` is the size of u's component, and sz[] holds the subtree
// sizes from the getroot pass that selected u.
inline void divide(int u)
{
    const int compSize = total;  // `total` is overwritten by the recursive calls below
    solve(u);                    // answer every path that passes through u
    vis[u] = 1;                  // retire u: its neighbours now lie in separate components
    for (int i = fir[u]; i; i = nxt[i])
        if (!vis[to[i]])
        {
            const int v = to[i];
            // In the getroot pass that found u, v is either a child of u (then sz[v] < sz[u]
            // and sz[v] is v's component size) or u's parent (then v's component is
            // everything outside u's subtree).
            total = sz[v] < sz[u] ? sz[v] : compSize - sz[u];
            root = 0;
            getroot(v, u);
            // Recurse on the centroid just found, not on v itself: splitting at an
            // arbitrary neighbour makes the recursion depth linear on a path and the
            // whole algorithm quadratic.
            divide(root);
        }
}

int main ()
{
    read(n), read(m);
    int x, y, z;
    for(int i = 1; i < n; i++)
        read(x), read(y), read(z), Add(x, y, z), Add(y, x, z);
    for(int i = 1; i <= m; i++) read(q[i]);
    total = n; mx[root=0] = n;
    getroot(1, 0); divide(root);
    for(int i = 1; i <= m; i++)
        puts(Ans[i] ? "AYE" : "NAY");
}
原文地址:https://www.cnblogs.com/Orz-IE/p/12039482.html