[BZOJ 3697] 采药人的路径

[题目链接]

         https://www.lydsy.com/JudgeOnline/problem.php?id=3697

[算法]

        首先 , 将黑色的边变成1 ,白色的边变成-1

        那么 , 问题就转化为了有多少条路径满足 :

        1. 路径长度为0 

        2. 路径中间存在一个点使得这个点可以将这条路径分成两段且长度为0

        考虑我们已经处理完了前面的子树 , 对于当前子树中一点x , 深度为d , 显然 , 前面的子树中要有一个深度为-d的点y ,这条路径合法当且仅当x到分治重心的路径上有一点满足d[z] = d[x]或y到分治重心的路径上有一点满足d[z] = d[y]

        那么我们可以记s[i][0 / 1]表示前面的子树中深度为i ,不存在/存在它到当前分治重心路径上的一个点,使得它的深度= i的点的个数

        考虑对于当前深度为d的点x,显然至少有s[−d][1]个点会和x形成一条合法的路径。 如果分治重心到x的路径上存在一个深度为d的点,那么还会有s[−d][0]个点会和x形成一条合法路径。

        详见代码 , 时间复杂度 : O(NlogN)

[代码]

         

#include<bits/stdc++.h>
using namespace std;
#define MAXN 200010
typedef long long LL;
const int inf = 2e9;

struct edge
{
        int to , w , nxt;
} e[MAXN << 1];

int n , tot , root , len;
int head[MAXN] , size[MAXN] , weight[MAXN] , cnt[MAXN << 1];
int s[MAXN << 1][2];
bool visited[MAXN];
LL ans;

template <typename T> inline void chkmax(T &x,T y) { x = max(x,y); }
template <typename T> inline void chkmin(T &x,T y) { x = min(x,y); }
template <typename T> inline void read(T &x)
{
    T f = 1; x = 0;
    char c = getchar();
    for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
    x *= f;
}
inline void addedge(int u,int v,int w)
{
        tot++;
        e[tot] = (edge){v , w , head[u]};
        head[u] = tot;
}
inline void getroot(int u , int fa , int total)
{
        weight[u] = 0;
        size[u] = 1;
        for (int i = head[u]; i; i = e[i].nxt)
        {
                int v = e[i].to;
                if (v == fa || visited[v]) continue;
                getroot(v , u , total);
                size[u] += size[v];
                chkmax(weight[u] , size[v]);
        }        
        chkmax(weight[u] , total - size[u]);
        if (weight[u] < weight[root]) root = u;
}
inline void calc(int u , int fa , int dep)
{
        ans += s[n - dep][1];
        if (cnt[n + dep] > 0) ans += s[n - dep][0];
        if (dep == 0 && cnt[n] > 1) ++ans;
        ++cnt[n + dep];
        for (int i = head[u]; i; i = e[i].nxt)
        {
                int v = e[i].to , w = e[i].w;
                if (v == fa || visited[v]) continue;
                calc(v , u , dep + w);        
        } 
        --cnt[n + dep];
}
inline void update(int u , int fa , int dep)
{
        ++s[n + dep][cnt[n + dep] > 0];
        ++cnt[n + dep];
        for (int i = head[u]; i; i = e[i].nxt)
        {
                int v = e[i].to , w = e[i].w;
                if (v == fa || visited[v]) continue;
                update(v , u , dep + w);
        }
        --cnt[n + dep];
}
inline void clear(int u , int fa , int dep)
{
        size[u] = 1;
        --s[n + dep][cnt[n + dep] > 0];
        ++cnt[n + dep];
        for (int i = head[u]; i; i = e[i].nxt)
        {
                int v = e[i].to , w = e[i].w;
                if (v == fa || visited[v]) continue;
                clear(v , u , dep + w);
                size[u] += size[v];
        }
        --cnt[n + dep];
}
inline void work(int u)
{
        visited[u] = true;
        cnt[n] = 1;
        for (int i = head[u]; i; i = e[i].nxt)
        {
                int v = e[i].to , w = e[i].w;
                if (visited[v]) continue;
                calc(v , u , w);
                update(v , u , w);        
        }        
        for (int i = head[u]; i; i = e[i].nxt)
        {
                int v = e[i].to , w = e[i].w;
                if (visited[v]) continue;
                clear(v , u , w);
        }
        for (int i = head[u]; i; i = e[i].nxt)
        {
                int v = e[i].to;
                if (visited[v]) continue;
                root = 0;
                getroot(v , 0 , size[v]);
                work(root);
        }
}

int main()
{
        
        read(n);
        for (int i = 1; i < n; i++)
        {
                int u , v , w;
                read(u); read(v); read(w);
                w = (w == 0) ? -1 : 1;
                addedge(u , v , w);
                addedge(v , u , w);
        }
        weight[root = 0] = inf;
        getroot(1 , 0 , n);
        work(root);
        printf("%lld
",ans);
        
        return 0;
    
}
原文地址:https://www.cnblogs.com/evenbao/p/9899056.html