P3565 由简单的树形dp 引入 长链刨分

  这道题感觉不太行 因为自己没想出来。

先说一下暴力吧,取三个点 让两两之间的距离相等怎么做呢,看起来是很复杂的样子的,但是仔细观察发现 答案出自一个点的儿子之间 或者儿子和父亲之间。

暴力枚举三个点然后 算两两点的距离 ST表的话 可以做到n^3 。

考虑 稍微暴力一点的解法 我们发现对于每个点我们统计的都是它的子树内部的答案和各个子树之间的答案以及各个子树之间及父亲之间的答案。

考虑枚举每一个点为中心 然后利用子树统计答案 具体我们发现这其实就是 完成了上述的过程。

复杂度n^2 。可以通过此题。非常的巧妙。比较暴力的解题。

//#include<bits/stdc++.h>
#include<iomanip>
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<queue>
#include<deque>
#include<cmath>
#include<ctime>
#include<cstdlib>
#include<stack>
#include<algorithm>
#include<vector>
#include<cctype>
#include<utility>
#include<set>
#include<bitset>
#include<map>
#define INF 1000000000
#define ll long long
#define min(x,y) ((x)>(y)?(y):(x))
#define max(x,y) ((x)>(y)?(x):(y))
#define RI register ll
#define db double
#define EPS 1e-8
using namespace std;
char buf[1<<15],*fs,*ft;
inline char getc()
{
    return (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),fs==ft))?0:*fs++;
}
inline int read()
{
    int x=0,f=1;char ch=getc();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getc();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getc();}
    return x*f;
}
const int MAXN=5010;
int n,len,maxx;
int d[MAXN];
ll ans,f1[MAXN],f2[MAXN],f[MAXN];
int lin[MAXN],ver[MAXN<<1],nex[MAXN<<1];
inline void add(int x,int y)
{
    ver[++len]=y;
    nex[len]=lin[x];
    lin[x]=len;
}
inline void dfs(int x,int father)
{
    d[x]=d[father]+1;++f[d[x]];
    maxx=max(maxx,d[x]);
    for(int i=lin[x];i;i=nex[i])
    {
        int tn=ver[i];
        if(tn==father)continue;
        dfs(tn,x);
    }
}
int main()
{
    freopen("1.in","r",stdin);
    n=read();
    for(int i=1;i<n;++i)
    {
        int x,y;
        x=read();y=read();
        add(x,y);add(y,x);
    }
    for(int i=1;i<=n;++i)
    {
        memset(f1,0,sizeof(f1));
        memset(f2,0,sizeof(f2));
        for(int j=lin[i];j;j=nex[j])
        {
            int tn=ver[j];
            maxx=0;d[i]=0;
            dfs(tn,i);
            for(int k=1;k<=maxx;++k)
            {
                ans+=f2[k]*f[k];
                f2[k]+=f[k]*f1[k];
                f1[k]+=f[k];f[k]=0;
            }
        }
    }
    printf("%lld
",ans);
    return 0;
}
View Code

当然也有我的原始思路 树形dp一下 f[i][j] 表示以i为根距i距离为j的点的个数 这个很好求f[i][j]=f[tn][j-1];f[x][0]=1;

考虑如何统计答案 在这个地方我遇到了一点小困难显然的是 答案出自自己的子树和子树和父亲之间 至于子树内部的东西我们可以递归来求解。

如何统计答案是一个重难点,这里有一个比较神仙了状态我也没有想出来想要统计答案必然的我们要先得到 距离i为j的点对的个数再用单个点的个数来计算。

设g[i][j]表示点对 直接到LCA的距离为d 到i这个点距离为d-j的个数 看起来非常的绕 但是 也比较自然因为这样才能与我们的f[i][j] 相结合起来组成答案。

一些 细节没有考虑清楚导致 wa 了很多次 我在进行统计答案的时候不光只有根的g数组*子树的f数组 还应该有子树的g数组*根的f数组(这一步代表子树和子树之间是双向的。

当然其中也有到根的转移 故父亲的那个地方也考虑到了 所以 是正确的。

//#include<bits/stdc++.h>
#include<iomanip>
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<queue>
#include<deque>
#include<cmath>
#include<ctime>
#include<cstdlib>
#include<stack>
#include<algorithm>
#include<vector>
#include<cctype>
#include<utility>
#include<set>
#include<bitset>
#include<map>
#define INF 1000000000
#define ll long long
#define min(x,y) ((x)>(y)?(y):(x))
#define max(x,y) ((x)>(y)?(x):(y))
#define RI register ll
#define db double
#define EPS 1e-8
using namespace std;
char buf[1<<15],*fs,*ft;
inline char getc()
{
    return (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),fs==ft))?0:*fs++;
}
inline int read()
{
    int x=0,f=1;char ch=getc();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getc();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getc();}
    return x*f;
}
const int MAXN=5010;
int n,len,maxx;
int d[MAXN];
ll ans;
ll f[MAXN][MAXN];//f[i][j]表示距i点距离为j时的点的个数显然有f[i][j]+=f[tn][j-1];
ll g[MAXN][MAXN];//g[i][j]表示点对距LCA的距离为d时距i点距离为d-j时的点对个数
//首先这里说明这个状态的必要性 子树内部的答案是递归处理的这个先不管
//自己子树与子树之间的答案 利用g[i][j]*f[i][j]来计算
//那么子树和父亲呢 显然 g[i][0] 就是讨论与父亲之间的答案的
//综上 解决的答案的统计 证毕。
//状态空间包涵整个问题 看起还很妙的样子。
int lin[MAXN],ver[MAXN<<1],nex[MAXN<<1];
inline void add(int x,int y)
{
    ver[++len]=y;
    nex[len]=lin[x];
    lin[x]=len;
}
inline void dfs(int x,int father)
{
    maxx=max(maxx,d[x]);
    for(int i=lin[x];i;i=nex[i])
    {
        int tn=ver[i];
        if(tn==father)continue;
        d[tn]=d[x]+1;
        dfs(tn,x);
    }
}
inline void dp(int x,int father)
{
    f[x][0]=1;
    for(int i=lin[x];i;i=nex[i])
    {
        int tn=ver[i];
        if(tn==father)continue;
        dp(tn,x);
        for(int j=maxx;j>=0;--j)
        {
            if(j-1>=0)
            {
                ans+=g[x][j]*f[tn][j-1];
                ans+=g[tn][j]*f[x][j-1];
                g[x][j]+=f[x][j]*f[tn][j-1];
                f[x][j]+=f[tn][j-1];
            }
            g[x][j]+=g[tn][j+1];
        }
    }
    //ans+=g[x][0];
}
int main()
{
    freopen("1.in","r",stdin);
    n=read();
    for(int i=1;i<n;++i)
    {
        int x,y;
        x=read();y=read();
        add(x,y);add(y,x);
    }
    dfs(1,0);
    dp(1,0);
    printf("%lld
",ans);
    return 0;
}
View Code

此题n<=5000 如果n是100000呢 怎么办n^2 挂掉的话我们 需要再次优化。

这就引入了我们经典的树上优化 长链刨分: 按照深度 划分轻重链 。

性质 1 所有链长的和是O(n)的。证明:所有点都在一条重链中 只被计算一次 因为链长总和是O(n);

性质 2 一个点的k次祖先y所在链的长度>=k 显然 

性质 3 一个点向上跳重链的次数不超过sqrt(n) 显然 ->1+2+3+...sqrt(n) 总和 *2>n;

有了这些性质我可以开始长链剖分 对于以上的问题,我们进行完长链刨分以后我们钦定 选取重儿子直接转移信息 轻儿子暴力转移信息。

那么 总复杂度 是 O(n) + sum(重链长度)  这样复杂度为O(n) 。

没人讲解 tmp 数组是干什么的 我也只能暂时性的意会一下。

//#include<bits/stdc++.h>
#include<iomanip>
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<queue>
#include<deque>
#include<cmath>
#include<ctime>
#include<cstdlib>
#include<stack>
#include<algorithm>
#include<vector>
#include<cctype>
#include<utility>
#include<set>
#include<bitset>
#include<map>
#define INF 1000000000
#define ll long long
#define min(x,y) ((x)>(y)?(y):(x))
#define max(x,y) ((x)>(y)?(x):(y))
#define RI register ll
#define db double
#define EPS 1e-8
using namespace std;
char buf[1<<15],*fs,*ft;
inline char getc()
{
    return (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),fs==ft))?0:*fs++;
}
inline int read()
{
    int x=0,f=1;char ch=getc();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getc();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getc();}
    return x*f;
}
const int MAXN=5010;
int n,len,maxx;
int d[MAXN],son[MAXN];
ll ans;
ll *f[MAXN],tmp[MAXN<<2],*id=tmp;//f[i][j]表示距i点距离为j时的点的个数显然有f[i][j]+=f[tn][j-1];
ll *g[MAXN];//g[i][j]表示点对距LCA的距离为d时距i点距离为d-j时的点对个数
//首先这里说明这个状态的必要性 子树内部的答案是递归处理的这个先不管
//自己子树与子树之间的答案 利用g[i][j]*f[i][j]来计算
//那么子树和父亲呢 显然 g[i][0] 就是讨论与父亲之间的答案的
//综上 解决的答案的统计 证毕。
//状态空间包涵整个问题 看起还很妙的样子。
int lin[MAXN],ver[MAXN<<1],nex[MAXN<<1];
inline void add(int x,int y)
{
    ver[++len]=y;
    nex[len]=lin[x];
    lin[x]=len;
}
inline void dfs(int x,int father)
{
    for(int i=lin[x];i;i=nex[i])
    {
        int tn=ver[i];
        if(tn==father)continue;
        dfs(tn,x);
        if(d[tn]>d[son[x]])son[x]=tn;
    }
    d[x]=d[son[x]]+1;
}
inline void dp(int x,int father)
{
    if(son[x])f[son[x]]=f[x]+1,g[son[x]]=g[x]-1,dp(son[x],x);
    f[x][0]=1;ans+=g[x][0];
        for(int i=lin[x];i;i=nex[i])
    {
        int tn=ver[i];
        if(tn==father||tn==son[x])continue;
        f[tn]=id;id+=d[tn]<<1;g[tn]=id;id+=d[tn]<<1;
        dp(tn,x);
        for(int j=d[tn];j>=0;--j)
        {
            if(j-1>=0)
            {
                ans+=g[x][j]*f[tn][j-1];
                ans+=g[tn][j]*f[x][j-1];
                g[x][j]+=f[x][j]*f[tn][j-1];
                f[x][j]+=f[tn][j-1];
            }
            if(j+1<=d[tn])g[x][j]+=g[tn][j+1];
        }
    }
}
int main()
{
    freopen("1.in","r",stdin);
    n=read();
    for(int i=1;i<n;++i)
    {
        int x,y;
        x=read();y=read();
        add(x,y);add(y,x);
    }
    dfs(1,0);
    f[1]=id;id+=d[1]<<1;g[1]=id;id+=d[1]<<1;
    dp(1,0);
    printf("%lld
",ans);
    return 0;
}
View Code

这 就是O(n) 的长链剖分 加速dp

原文地址:https://www.cnblogs.com/chdy/p/11300692.html