P2634 [国家集训队]聪聪可可

传送门

点分治入门题

首先可以直接枚举所有两点的lca强行dp

设 $f [ x ] [ 0/1/2 ]$ 表示节点 $x$ 在模3意义下,$x$ 的子树所有节点到 $x$ 的距离为 $0/1/2$ 时的方案数

初始 $f [ x ] [ 0 ] =1$ (本身到自己有一种方案)

转移就枚举所有儿子 $v$ ,设 $x$ 到 $v$ 的距离为 $w$,那么转移显然为:

$ f [ x ] [ (w+0)\%3 ] += f [ v ] [ 0 ] $ 

$ f [ x ] [ (w+1)\%3 ] += f [ v ] [ 1 ] $ 

$ f [ x ] [ (w+2)\%3 ] += f [ v ] [ 2 ] $ 

统计答案也十分显然,对 $w$ 分类讨论一下就好了:

inline void work(int x)//注意函数名不是"dfs",x就是我们枚举的lca
{
    f[x][0]=1; f[x][1]=f[x][2]=0;
    for(int i=fir[x];i;i=from[i])
    {
        int &v=to[i],&w=val[i]; if(vis[v]) continue; dfs(v,x);//dfs求出儿子的f
        if(w==0) ans+=f[x][0]*f[v][0]+f[x][1]*f[v][2]+f[x][2]*f[v][1];
        if(w==1) ans+=f[x][0]*f[v][2]+f[x][1]*f[v][1]+f[x][2]*f[v][0];
        if(w==2) ans+=f[x][0]*f[v][1]+f[x][1]*f[v][0]+f[x][2]*f[v][2];
     //注意先统计ans再转移f f[x][w]
+=f[v][0]; f[x][fk(w+1)]+=f[v][1]; f[x][fk(w+2)]+=f[v][2]; } }

但是最坏情况会被卡到 $O(n^2)$

所以上点分治,每次找重心作lca,这样每次子树大小至少减半

枚举lca复杂度$O(n)$,搞dp因为子树大小每次减半所以复杂度约为 $O(log_n)$

总复杂度 $O(nlog_n)$

注意long long

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<vector>
using namespace std;
typedef long long ll;
inline int read()
{
    int x=0,f=1; char ch=getchar();
    while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); }
    return x*f;
}
const int N=5e5+7,INF=1e9+7;
int fir[N],from[N<<1],to[N<<1],val[N<<1],cntt;
inline void add(int &a,int &b,int &c)
{
    from[++cntt]=fir[a]; fir[a]=cntt;
    to[cntt]=b; val[cntt]=c;
}
inline int fk(int x) { return x>=3 ? x-3 : x; }
int n,rt,tot;
ll ans,f[N][3];
int sz[N],mx[N];
bool vis[N];
void find_rt(int x,int fa)//找重心
{
    mx[x]=0; sz[x]=1;
    for(int i=fir[x];i;i=from[i])
    {
        int &v=to[i]; if(vis[v]||v==fa) continue;
        find_rt(v,x); sz[x]+=sz[v];
        mx[x]=max(mx[x],sz[v]);
    }
    mx[x]=max(mx[x],tot-sz[x]);
    if(mx[x]<mx[rt]) rt=x;
}
void dfs(int x,int fa)//dfs先求出子树的f
{
    f[x][0]=1; f[x][1]=f[x][2]=0;
    for(int i=fir[x];i;i=from[i])
    {
        int &v=to[i],&w=val[i]; if(vis[v]||v==fa) continue; dfs(v,x);
        f[x][w]+=f[v][0];
        f[x][fk(w+1)]+=f[v][1];
        f[x][fk(w+2)]+=f[v][2];
    }
}
inline void work(int x)//统计答案
{
    f[x][0]=1; f[x][1]=f[x][2]=0;
    for(int i=fir[x];i;i=from[i])
    {
        int &v=to[i],&w=val[i]; if(vis[v]) continue; dfs(v,x);
        if(w==0) ans+=f[x][0]*f[v][0]+f[x][1]*f[v][2]+f[x][2]*f[v][1];
        if(w==1) ans+=f[x][0]*f[v][2]+f[x][1]*f[v][1]+f[x][2]*f[v][0];
        if(w==2) ans+=f[x][0]*f[v][1]+f[x][1]*f[v][0]+f[x][2]*f[v][2];
        f[x][w]+=f[v][0];
        f[x][fk(w+1)]+=f[v][1];
        f[x][fk(w+2)]+=f[v][2];
    }
}
void solve(int x)//点分治
{
    vis[x]=1; work(x);
    for(int i=fir[x];i;i=from[i])
    {
        int &v=to[i]; if(vis[v]) continue;
        tot=sz[v]; rt=0;
        find_rt(v,0); solve(rt);
    }
}
inline ll gcd(ll a,ll b) { return b ? gcd(b,a%b) : a; }
int main()
{
    //freopen("data.in","r",stdin);
    //freopen("data.out","w",stdout);
    int a,b,c;
    n=read();
    for(int i=1;i<n;i++)
    {
        a=read(),b=read(),c=read()%3;
        add(a,b,c); add(b,a,c);
    }
    tot=n; mx[rt]=INF;
    find_rt(1,0); solve(rt);
    ans=ans*2+n; ll d=gcd(ans,1ll*n*n);
    printf("%lld/%lld",ans/d,1ll*n*n/d);
    return 0;
}
点分治解法

其实此题不用点分治

可以直接树形dp,转移同上...代码又短又好写

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<vector>
using namespace std;
typedef long long ll;
inline int read()
{
    int x=0,f=1; char ch=getchar();
    while(ch<'0'||ch>'9') { if(ch=='-') f=-1; ch=getchar(); }
    while(ch>='0'&&ch<='9') { x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); }
    return x*f;
}
const int N=5e5+7;
inline int fk(int x) { return x>=3 ? x-3 : x; }
int fir[N],from[N<<1],to[N<<1],val[N<<1],cntt;
inline void add(int &a,int &b,int &c)
{
    from[++cntt]=fir[a]; fir[a]=cntt;
    to[cntt]=b; val[cntt]=c;
}
ll ans,f[N][3];
void dfs(int x,int fa)
{
    f[x][0]=1;
    for(int i=fir[x];i;i=from[i])
    {
        int &v=to[i],&w=val[i]; if(v==fa) continue;
        dfs(v,x);
        if(w==0) ans+=f[x][0]*f[v][0]+f[x][1]*f[v][2]+f[x][2]*f[v][1];
        if(w==1) ans+=f[x][0]*f[v][2]+f[x][1]*f[v][1]+f[x][2]*f[v][0];
        if(w==2) ans+=f[x][0]*f[v][1]+f[x][1]*f[v][0]+f[x][2]*f[v][2];
        f[x][w]+=f[v][0];
        f[x][fk(w+1)]+=f[v][1];
        f[x][fk(w+2)]+=f[v][2];
    }
}
int n;
ll gcd(ll a,ll b) { return b ? gcd(b,a%b) : a; }
int main()
{
    int a,b,c;
    n=read();
    for(int i=1;i<n;i++)
    {
        a=read(),b=read(),c=read()%3;
        add(a,b,c); add(b,a,c);
    }
    dfs(1,1); ans<<=1; ans+=n;
    ll d=gcd(ans,1ll*n*n);
    printf("%lld/%lld",ans/d,1ll*n*n/d);
    return 0;
}
原文地址:https://www.cnblogs.com/LLTYYC/p/10302855.html