树分治learning

学习了树的点分治,树的边分治似乎因为复杂度过高而并不出众,于是没学

自己总结了一下 有些时候面对一些树上的结构 并且解决的是和路径有关的问题的时候 如果是多个询问 关注点在每次给出两个点,求一些关于这两个点之间路径的问题的时候,我们可以使用树链剖分,但是如果是给出一个单一的询问,但是很宏观 类似于求所有点对之间路径满足xx的数量,这时候我们可以树形dp做些什么 但是有时候会遇到一些树形dp难以解决的东西,类似于数组开不下,无法转移状态这种问题,就可以用树分治

树分治基于一个思想 先确定一个点 找到所有过这个点的路径并判断 再对被这个点分开的连通块做同样的操作

QAQ于是做了几道模板题

POJ1741 模板题 求一棵树中 满足点对之间路径加和小于k的数量

这里用到了一个动态规划 判断一个数组中 选两个数能加起来<k  做法是sort后维护两个指针

其余的地方都很模板,需要注意的是要对root的所有son进行solve之后再重置lr

#include<stdio.h>
#include<math.h>
#include<string.h>
#include<vector>
#include<queue>
#include<map>
#include<string>
#include<iostream>
#include<algorithm>
#include<stack>
using namespace std;
#define L long long
#define pb push_back
#define lala printf("--------
");
#define ph push
#define rep(i, a, b) for (int i=a;i<=b;++i)
#define dow(i, b, a) for (int i=b;i>=a;--i)
#define fmt(i,n) if(i==n)printf("
");else printf(" ") ;
#define fi first
#define se second
template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));}
int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
int n , m ;
bool vis[10050] ;
int dis[10050] ;
///---
struct node {
    int v,w,nex;
}b[10050*2];
int tot ;
int head[10050] ;
void init() {
    flc(head,-1);
    tot = 0 ;
}
void add(int u,int v,int w) {
    tot ++ ;
    b[tot].v=v;
    b[tot].w=w;
    b[tot].nex=head[u];
    head[u]=tot;
}
///---
int son[10050] ;
int getsize(int u,int fa) {
    son[u] = 1 ;
    for(int i = head[u] ; i != -1 ; i = b[i].nex) {
        int v = b[i].v ;
        if(v==fa || vis[v]) continue ;
        son[u]+=getsize(v,u);
    }
    return son[u];
}
int minn ;
void getroot(int u,int fa,int &root,int siz) {
    int maxx = siz - son[u] ;
    for(int i = head[u] ; i != -1 ; i = b[i].nex) {
        int v=b[i].v ;
        if(v==fa || vis[v]) continue ;
        getroot(v,u,root,siz) ;
        maxx = max(maxx,son[v]) ;
    }
    if(minn == -1 || maxx < minn) {
         minn = maxx ;
         root = u ;
    }
}
///---
int l , r ;
void getdepth(int u,int fa,int xd) {
    dis[++r] = xd ;
    for(int i = head[u] ; i != -1 ; i = b[i].nex) {
        int v = b[i].v ;
        int w = b[i].w ;
        if(v == fa || vis[v]) continue ;
        getdepth(v , u , xd + w) ;
    }
}
bool cmp(int a, int b) {
    return a<b;
}
int getdep(int l , int r) {
    if(l >= r) return 0 ;
    sort(dis + l , dis + r + 1 , cmp ) ;
    int res = 0 ;
    int le = l ;
    int ri = l-1 ;
    while(ri+1 <= r && dis[ri+1] + dis[le] <= m) {
        ri ++ ;
        res ++ ;
    }
    while(le + 1 <= r) {
        le ++ ;
        while(ri >= l && dis[ri] + dis[le] > m) ri -- ;
        res += ri - l + 1 ;
    }

    for(int i = l ; i <= r ; i ++ ) {
        if(dis[i]*2 <= m) res -- ;
    }
    return (res / 2) ;
}
///---
int solve(int u) {
    int siz = getsize(u , -1) ;
    minn = -1 ;
    int root = -1 ;
    getroot(u , -1 , root , siz) ;
    vis[root] = true ;
    int res = 0 ;
    for(int i = head[root] ; i != -1 ; i = b[i].nex) {
        int v = b[i].v ;
        if(vis[v]) continue ;
        int z = solve(v) ;
        res += z ;
    }
    l = 1 ;
    r = 0 ;
    for(int i = head[root] ; i != -1 ; i = b[i].nex) {
        int v = b[i].v ;
        int w = b[i].w ;
        if(vis[v]) continue ;
        getdepth(v , root , w) ;
        res -= getdep(l , r) ;
        l = r + 1 ;
    }
    res += getdep(1 , r) ;
    for(int i = 1 ; i <= r ; i ++ ) {
        if(dis[i] <= m) res ++ ;
        else break ;
    }
    vis[root] = false ;
    return res ;
}

int main () {
    while(scanf("%d%d" , &n, &m) != EOF) {
        if(n==0&&m==0) break ;
        init() ;
        rep(i,1,n-1) {
            int u,v,w;
            u=read();v=read();w=read();
            add(u,v,w);
            add(v,u,w);
        }
        memset(vis,false,sizeof(vis));
        int ans = solve(1) ;
        printf("%d
" , ans) ;
    }
}

BZOJ 2152 求路径%3==0的点对的数量

因为只是%3 所以比较容易些。。如果写树形dp的话会比较好写 维护一个dp[n][3]的数组就可以

但是如果不是3是很大的数字 就得开dp[n][m] 如果开不下的话就得树分治

/// 树形dp跑得又好又快QAQ

#include<stdio.h>
#include<math.h>
#include<string.h>
#include<vector>
#include<queue>
#include<map>
#include<string>
#include<iostream>
#include<algorithm>
#include<stack>
using namespace std;
#define L long long
#define pb push_back
#define lala printf("--------
");
#define ph push
#define rep(i, a, b) for (int i=a;i<=b;++i)
#define dow(i, b, a) for (int i=b;i>=a;--i)
#define fmt(i,n) if(i==n)printf("
");else printf(" ") ;
#define fi first
#define se second
template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));}
int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
int n ;
struct node {
    int v,w,nex;
}b[20050 * 2];
int head[20050];
int tot ;
void add(int u,int v,int w) {
    tot++;
    b[tot].v=v;b[tot].w=w;
    b[tot].nex=head[u];head[u]=tot;
}
void init() {
    flc(head,-1);
    tot=0;
}
 
int dp[20050][5] ;
int ans ;
 
void dfs(int u,int fa) {
    int a[4];
    flc(a,0) ;
    a[0] = 1 ;
    for(int i=head[u];i!=-1;i=b[i].nex) {
        int v=b[i].v ;
        int w=b[i].w ;
        if(v==fa) continue ;
        dfs(v,u) ;
        rep(j,0,2) {
            dp[u][(j+w)%3]+=dp[v][j] ;
        }
        rep(j,0,2) {
            int z=w+j;
            z%=3 ;
            if(z==0) {
                ans += a[0]*dp[v][j] ;
            }
            if(z==1) {
                ans += a[2]*dp[v][j] ;
            }
            if(z==2) {
                ans += a[1]*dp[v][j] ;
            }
        }
        rep(j,0,2) {
            a[(j+w)%3]+=dp[v][j] ;
        }
    }
    dp[u][0] ++ ;
}
 
int main (){
    while(scanf("%d" , &n) != EOF) {
        init() ;
        flc(dp,0);
        ans = 0 ;
        rep(i,1,n-1){
            int u=read(),v=read(),w=read();
            add(u,v,w);add(v,u,w);
        }
        dfs(1,-1);
        int fm = n*n;
        ans *= 2 ;
        ans += n ;
        int gc = __gcd(fm,ans) ;
        fm/=gc ;
        ans/=gc ;
        printf("%d/%d",ans,fm) ;
    }
}
#include<stdio.h>
#include<math.h>
#include<string.h>
#include<vector>
#include<queue>
#include<map>
#include<string>
#include<iostream>
#include<algorithm>
#include<stack>
using namespace std;
#define L long long
#define pb push_back
#define lala printf("--------
");
#define ph push
#define rep(i, a, b) for (int i=a;i<=b;++i)
#define dow(i, b, a) for (int i=b;i>=a;--i)
#define rnode(i,u) for(int i = head[u] ; i != -1 ; i = b[i].nex)
#define fmt(i,n) if(i==n)printf("
");else printf(" ") ;
#define fi first
#define se second
template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));}
int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
 
int n , m ;
int ans ;
bool vis[20050] ;
int dis[20050] ;
///---
struct node {
    int v,w,nex;
}b[20050*2];
int tot ;
int head[20050] ;
void init() {
    flc(head,-1);
    tot = 0 ;
}
void add(int u,int v,int w) {
    tot ++ ;
    b[tot].v=v;
    b[tot].w=w;
    b[tot].nex=head[u];
    head[u]=tot;
}
///---
int siz[20050];
int getsize(int u,int fa) {
    siz[u] = 1 ;
    for(int i=head[u];i!=-1;i=b[i].nex) {
        int v=b[i].v;
        if(v==fa||vis[v]) continue ;
        siz[u]+=getsize(v,u);
    }
    return siz[u];
}
int minn ;
void getroot(int u,int fa,int num,int &root) {
    int maxx=0;
    for(int i=head[u];i!=-1;i=b[i].nex){
        int v=b[i].v ;
        if(v==fa||vis[v]) continue ;
        getroot(v,u,num,root);
        maxx=max(maxx,siz[v]);
    }
    maxx=max(maxx,num-siz[u]);
    if(maxx<minn){
        minn=maxx;root=u;
    }
}
///---
int l,r;
void getdepth(int u,int fa,int xd) {
    dis[++r]=xd ;
    rnode(i,u) {
        int v=b[i].v ;
        if(v==fa||vis[v]) continue ;
        int w=b[i].w ;
        getdepth(v,u,xd+w) ;
    }
}
int getdep(int l,int r) {
    if(l>r) return 0 ;
    int a[3] ; flc(a,0) ;
    rep(i,l,r) {
        a[dis[i]%3] ++ ;
    }
    int res = 0 ;
    rep(i,0,2) {
        rep(j,0,2) {
            if((i+j)%3==0) res += a[i]*a[j] ;
        }
    }
    return res ;
}
///---
int solve(int u) {
    int num = getsize(u,-1);
    minn = 999999999 ;
    int root ;
    getroot(u,-1,num,root);
    int ans = 0 ;
    vis[root]=true;
    rnode(i,root) {
        int v=b[i].v;
        if(vis[v]) continue ;
        ans += solve(v) ;
    }
    l = 1 ;
    r = 0 ;
    rnode(i,root) {
        int v=b[i].v;
        int w=b[i].w;
        if(vis[v]) continue ;
        getdepth(v,root,w) ;
        ans -= getdep(l,r) ;
        l = r + 1 ;
    }
    dis[++r] = 0 ;
    ans += getdep(1,r) ;
    vis[root] = false ;
    return ans ;
}
 
 
 
 
int main () {
    while(scanf("%d" , &n) != EOF) {
        init() ;
        rep(i,1,n-1) {
            int u=read();int v=read() ; int w = read();
            add(u,v,w) ; add(v,u,w) ;
        }
        memset(vis,false,sizeof(vis)) ;
        int ans = solve(1) ;
        int fm = n*n ;
        int g = __gcd(ans,fm) ;
        fm/=g ;
        ans/=g ;
        printf("%d/%d
" , ans , fm) ;
    }
}

HDU 5977 大连的铜牌题 求包含所有颜色的路径的数目 k<=10

这个题的颜色来源于点 在对root的son进行getdepth的时候 需要把root的颜色给带下去 因为我们用ans-root的同一个son内的孩子 里面肯定是包含root的颜色的 这个关系是或 所以可以直接或上去

#include<stdio.h>
#include<math.h>
#include<string.h>
#include<vector>
#include<queue>
#include<map>
#include<string>
#include<iostream>
#include<algorithm>
#include<stack>
using namespace std;
#define L long long
#define pb push_back
#define lala printf("--------
");
#define ph push
#define rep(i, a, b) for (L i=a;i<=b;++i)
#define dow(i, b, a) for (L i=b;i>=a;--i)
#define rnode(i,u) for(L i = head[u] ; i != -1 ; i = b[i].nex)
#define fmt(i,n) if(i==n)printf("
");else printf(" ") ;
#define fi first
#define se second
template<class T> inline void flc(T &A, L x){memset(A, x, sizeof(A));}
L read(){L x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}

L n , k , m ;
L ans ;
bool vis[50050] ;
L bl[50050] ;
///---
struct node {
    L v,nex;
}b[50050*2];
L tot ;
L head[50050] ;
L dis[50050] ;
void init() {
    flc(head,-1);
    tot = 0 ;
    memset(vis,false,sizeof(vis)) ;
}
void add(L u,L v) {
    tot ++ ;
    b[tot].v=v;
    b[tot].nex=head[u];
    head[u]=tot;
}
///---
vector<int>q[2050] ;
void thefirst() {
    L z = (1<<k)-1 ;
    rep(i,0,1024) q[i].clear() ;
    rep(i,0,z) {
        rep(j,0,z) {
            if( (i|j) == z) {
                q[i].pb(j) ;
            }
        }
    }
}
///---
L siz[50050];
L getsize(L u,L fa) {
    siz[u] = 1 ;
    rnode(i,u){
        L v=b[i].v;
        if(v==fa||vis[v]) continue ;
        siz[u]+=getsize(v,u);
    }
    return siz[u];
}
L minn ;
void getroot(L u,L fa,L num,L &root) {
    L maxx=0;
    rnode(i,u){
        L v=b[i].v ;
        if(v==fa||vis[v]) continue ;
        getroot(v,u,num,root);
        maxx=max(maxx,siz[v]);
    }
    maxx=max(maxx,num-siz[u]);
    if(minn==-1||maxx<minn){
        minn=maxx;root=u;
    }
}
///---
L l , r ;
void getdepth(L u,L fa,L xd) {
    xd |= (1 << (bl[u]-1)) ;
    dis[++r] = xd ;
    rnode(i,u) {
        L v=b[i].v;
        if(vis[v] || v==fa) continue ;
        getdepth(v,u,xd) ;
    }
}
L mp[2050] ;
L getdep(L l , L r) {
    if(l>r) return 0 ;
    flc(mp,0) ;
    L ans = 0 ;
    rep(i,l,r) {
        L x=dis[i] ;
        for(L i=0;i<q[x].size();i++){
            L y=q[x][i];
            ans += mp[y] ;
        }
        mp[x] ++ ;
    }
    return ans ;
}
///---
L solve(L u) {
    L siz = getsize(u,-1) ;
    L root = -1;
    minn = -1 ;
    getroot(u,-1,siz,root) ;
    vis[root]=true ;
    L ans = 0 ;
    rnode(i,root) {
        L v=b[i].v;
        if(vis[v]) continue ;
        ans += solve(v) ;
    }
    l = 1 ;
    r = 0 ;
    rnode(i,root) {
        L v=b[i].v ;
        if(vis[v]) continue ;
        getdepth(v,root,(1<<(bl[root]-1))) ;
        ans -= getdep(l,r) ;
        l = r + 1 ;
    }
    L x = (1<<(bl[root]-1)) ;
    L K = (1<<k)-1 ;
    rep(i,1,r) {
        if((x | dis[i]) == K) {
            ans ++ ;
        }
    }
    ans += getdep(1,r) ;
    vis[root]=false;
    return ans ;
}

int main () {
    while(scanf("%lld%lld" , &n,&k) != EOF) {
        init() ;
        thefirst() ;
        rep(i,1,n) bl[i] = read() ;
        rep(i,1,n-1) {
            L u=read(),v=read();
            add(u,v) ;
            add(v,u) ;
        }
        if(k == 1) {
            printf("%lld
" , n*n) ;
            continue ;
        }
        L ans = solve(1) ;
        printf("%lld
" , ans*2) ;
    }
}

学会了树分治之后开启了新技能“看见什么不明显DP的树上结构就觉得可以树分治” 感觉要分治算法学傻。。

训练赛看到一个题 感觉树形DP不可做 于是想树分治 发现解决不了这个问题 但是感觉还是树分治 赛后发现果然

uvaLive 6900 给出一棵树 每条边有cost与val 我有C 在树上选一条路径出来 使sum(cost) <= C时的最大val

这个和加减不太一样 因为加减是可以通过对root的son来操作进行去重的 上一个大连的是进行或运算 也无可厚非 但是这个求max 是不可逆的

但是我们本来就不需要去重 和以前模板思路不一样的是 我们保存dis数组中 每一个值来自哪个root的儿子R 然后对R排序 处理完一个R再搞另一个R 我们不需要排序 因为根据dfs的特性 相同的R一定有且只有一段 所以不需要sort 和之前的去重没有什么时间上的差别 因为省去了去重的时间 所以我想 时间应该会更快

在第一道题里面 用一个sort+O(n)单调思想 其实sort就撑到nlogn了 所以之后的nlogn也是可以接受的 可以做一个树状数组 来维护前缀max

因为不能开太大 所以进行一个离散化 时间也是nlogn的 最后的复杂度还是nlognlogn 虽然常数大点

这种思想是泛用的 之前的几道题也可以这么做

uvaLive 6900

#include<stdio.h>
#include<math.h>
#include<string.h>
#include<vector>
#include<queue>
#include<map>
#include<string>
#include<iostream>
#include<algorithm>
#include<stack>
using namespace std;
#define L long long
#define pb push_back
#define lala printf("--------
");
#define ph push
#define rep(i, a, b) for (int i=a;i<=b;++i)
#define dow(i, b, a) for (int i=b;i>=a;--i)
#define fmt(i,n) if(i==n)printf("
");else printf(" ") ;
#define rnode(i,u) for(int i = head[u] ; i != -1 ; i = b[i].nex)
#define fi first
#define se second
template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));}
int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
int n , m;
struct node {
    int vol,val;
    int R ;
}dis[20050];
bool vis[20050];
struct no {
    int v,vol,val,nex;
}b[20050*2];
int head[20050];
int tot;
void init() {
    flc(head,-1);
    tot=0;
}
void add(int u,int v,int vol,int val) {
    tot++ ;
    b[tot].v=v;b[tot].vol=vol;b[tot].val=val;
    b[tot].nex=head[u] ; head[u]=tot;
}
int V ;
///---
int son[20050] ;
int getsize(int u,int fa) {
    son[u] = 1 ;
    rnode(i,u) {
        int v=b[i].v;
        if(v==fa || vis[v]) continue ;
        son[u] += getsize(v,u) ;
    }
    return son[u] ;
}
int minn ;
void getroot(int u,int fa,int &root,int siz) {
    int maxx = siz - son[u] ;
    rnode(i,u) {
        int v=b[i].v ;
        if(v==fa || vis[v]) continue ;
        if(son[v] > maxx) maxx = son[v] ;
        getroot(v,u,root,siz) ;
    }
    if(maxx < minn) {
        minn = maxx;
        root = u ;
    }
}
///---
int l,r ;
void getdepth(int u,int fa,int xdvol,int xdval,int sp) {
    node tmp ;
    tmp.vol = xdvol ;
    tmp.val = xdval ;
    tmp.R = sp ;
    dis[++r] = tmp ;
    rnode(i,u) {
        int v = b[i].v ;
        if(v == fa || vis[v]) continue ;
        getdepth(v,u,xdvol + b[i].vol,xdval + b[i].val,sp) ;
    }
}
int c[40050] ;
int lowbit(int x) {
    return (x&(-x)) ;
}
void segadd(int x,int ma) {
    while(x<=40000) {
        c[x]=max(c[x] , ma) ;
        x+=lowbit(x) ;
    }
}
int fin(int x) {
    int res = 0 ;
    while(x>0) {
        res = max(c[x],res);
        x-=lowbit(x) ;
    }
    return res ;
}
int calc(int l,int r) {
    if(l > r) return 0 ;
    flc(c,0) ;
    vector<int>ls ; ls.clear() ;
    rep(i,l,r) {
        ls.pb(dis[i].vol) ;
    }
    int res = 0 ;
    sort(ls.begin(),ls.end()) ;
    ls.erase(unique(ls.begin(),ls.end()) , ls.end()) ;
    for(int i = l ; i <= r ; i ++ ) {
        int j = i ;
        while(j <= r && dis[j].R == dis[i].R) {
            int z = dis[j].vol ;
            int val1 = dis[j].val ;
            if(z > V) {
                j ++ ;
                continue ;
            }
            int x = V - z ;
            int id = -2 ;
            int ll = 0 ;
            int rr = ls.size()-1 ;
            while(ll<=rr) {
                int mid=(ll+rr)/2 ;
                if(ls[mid]<=x) {
                    id=mid;
                    ll=mid+1;
                }
                else {
                    rr=mid-1;
                }
            }
            if(id==-2){
                j++;
                continue ;
            }
            int rres = fin(id+1) ;
            res = max(res , rres + dis[j].val) ;
            j ++ ;
        }
        j -- ;
        rep(k,i,j) {
            int vol = dis[k].vol ;
            int val = dis[k].val ;
            int id = lower_bound(ls.begin(),ls.end(),vol)-ls.begin()+1 ;
            segadd(id,val) ;
        }
        i = j ;
    }
    return res ;
}
///---
int solve(int u) {
    int siz = getsize(u,-1) ;
    minn = 999999999 ;
    int root ;
    getroot(u,-1,root,siz) ;
    vis[root] = true ;
    int res = 0 ;
    rnode(i,root) {
        int v=b[i].v ;
        if(vis[v]) continue ;
        int x = solve(v) ;
        res = max(res,x) ;
    }
    l = 1 ;
    r = 0 ;
    rnode(i,root) {
        int v = b[i].v ;
        int vol = b[i].vol ; int val = b[i].val ;
        if(vis[v]) continue ;
        getdepth(v,root,vol,val,v) ;
    }
    res = max(res , calc(1,r)) ;
    rep(i,1,r) {
        if(dis[i].vol <= V) {
            res = max(dis[i].val , res) ;
        }
    }
    vis[root] = false ;
    return res ;
}


int main () {
    int t = read();
    while(t -- ) {
        n = read();
        init() ;
        rep(i,2,n) {
            int u=read(),v=read(),vol=read(),val=read() ;
            add(u,v,vol,val);
            add(v,u,vol,val);
        }
        V = read() ;
        memset(vis,false,sizeof(vis));
        int ans=solve(1) ;
        printf("%d
" , ans) ;
    }
}

BZOJ 2152 用这种方法改了一下 发现由于必须sort 所以复杂度比之前的做法要多一个log

#include<stdio.h>
#include<math.h>
#include<string.h>
#include<vector>
#include<queue>
#include<map>
#include<string>
#include<iostream>
#include<algorithm>
#include<stack>
using namespace std;
#define L long long
#define pb push_back
#define lala printf("--------
");
#define ph push
#define rep(i, a, b) for (int i=a;i<=b;++i)
#define dow(i, b, a) for (int i=b;i>=a;--i)
#define rnode(i,u) for(int i = head[u] ; i != -1 ; i = b[i].nex)
#define fmt(i,n) if(i==n)printf("
");else printf(" ") ;
#define fi first
#define se second
template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));}
int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
 
int n , m ;
int ans ;
bool vis[20050] ;
struct no {
    int x , R;
}dis[20050] ;
///---
struct node {
    int v,w,nex;
}b[20050*2];
int tot ;
int head[20050] ;
void init() {
    flc(head,-1);
    tot = 0 ;
}
void add(int u,int v,int w) {
    tot ++ ;
    b[tot].v=v;
    b[tot].w=w;
    b[tot].nex=head[u];
    head[u]=tot;
}
///---
int siz[20050];
int getsize(int u,int fa) {
    siz[u] = 1 ;
    for(int i=head[u];i!=-1;i=b[i].nex) {
        int v=b[i].v;
        if(v==fa||vis[v]) continue ;
        siz[u]+=getsize(v,u);
    }
    return siz[u];
}
int minn ;
void getroot(int u,int fa,int num,int &root) {
    int maxx=0;
    for(int i=head[u];i!=-1;i=b[i].nex){
        int v=b[i].v ;
        if(v==fa||vis[v]) continue ;
        getroot(v,u,num,root);
        maxx=max(maxx,siz[v]);
    }
    maxx=max(maxx,num-siz[u]);
    if(maxx<minn){
        minn=maxx;root=u;
    }
}
///---
int l,r;
void getdepth(int u,int fa,int xd,int sp) {
    no tmp ;
    tmp.x = xd ;
    tmp.R = sp ;
    dis[++r] = tmp ;
    rnode(i,u) {
        int v=b[i].v ;
        if(v==fa||vis[v]) continue ;
        int w=b[i].w ;
        getdepth(v,u,xd+w,sp) ;
    }
}
int getdep(int l,int r) {
    if(l>r) return 0;
    int res = 0 ;
    int a[5] ; flc(a,0) ;
    for(int i = l ; i <= r ; i ++ ) {
        int j = i ;
        while(j <= r && dis[j].R==dis[i].R) {
            int x = dis[j].x % 3 ;
            int ned = 3 - x ;
            ned %= 3 ;
            res += a[ned] ;
            j ++ ;
        }
        j -- ;
        rep(k,i,j) {
            int x = dis[k].x % 3 ;
            a[x] ++ ;
        }
        i = j ;
    }
    return res ;
}
///---
int solve(int u) {
    int num = getsize(u,-1);
    minn = 999999999 ;
    int root ;
    getroot(u,-1,num,root);
    int ans = 0 ;
    vis[root]=true;
    rnode(i,root) {
        int v=b[i].v;
        if(vis[v]) continue ;
        ans += solve(v) ;
    }
    l = 1 ;
    r = 0 ;
    rnode(i,root) {
        int v=b[i].v;
        int w=b[i].w;
        if(vis[v]) continue ;
        getdepth(v,root,w,v) ;
    }
    ans += getdep(1,r) ;
    rep(i,1,r) {
        if(dis[i].x % 3 == 0) ans ++ ;
    }
    vis[root] = false ;
    return ans ;
}
 
 
 
 
int main () {
    while(scanf("%d" , &n) != EOF) {
        init() ;
        rep(i,1,n-1) {
            int u=read();int v=read() ; int w = read();
            add(u,v,w) ; add(v,u,w) ;
        }
        memset(vis,false,sizeof(vis)) ;
        int ans = solve(1) ;
        ans *= 2;
        ans += n ;
        int fm = n*n ;
        int g = __gcd(ans,fm) ;
        fm/=g ;
        ans/=g ;
        printf("%d/%d
" , ans , fm) ;
    }
}
原文地址:https://www.cnblogs.com/rayrayrainrain/p/7598082.html