luogu P5002 专心OI

题目描述

Imakf是一个小蒟蒻,他最近刚学了LCA,他在手机APP里看到一个游戏也叫做LCA就下载了下来。

这个游戏会给出你一棵树,这棵树有N个节点,根结点是R,系统会选中M个点P1,P2...PM,要Imakf回答有多少组点对(ui,vi)的最近公共祖先是Pi。Imakf是个小蒟蒻,他就算学了LCA也做不出,于是只好求助您了。

Imakf毕竟学过一点OI,所以他允许您把答案模 (10^9+7)

输入格式

第一行 N , R , M

此后N-1行 每行两个数a,b 表示a,b之间有一条边

此后1行M个数 表示P_i

输出格式

M行,每行一个数,第i行的数表示有多少组点对(u_i,v_i)的最近公共祖先是P_i

N≤10000,M≤50000


显然这题根LCA没有多大关系......

设size(x)表示以x为根的子树大小(x自己也要算),我们来愉快地推式子

设存在点u,v,并且x是u,v的祖先。考虑u,v之间的路径(不包含u,v)不经过x,那么根据往日做LCA的经验,只有当u,v至少其中一个等于x时两个点的LCA才会是x。设x一共有k棵子树,那么此时:

[ans1=sum_{i=1}^{k}size[son[i]]*2+1=size[x]*2-1 ]

再考虑经过x的情况,此时:

[ans2=sum_{i=1}^{k}sum_{j=1}^{k}size[son[i]]*size[son[j]] ]

[ans2=sum_{i=1}^{k}size[son[i]]*(size[x]-1) ]

[ans2=(size[x]-1)^2 ]

然后减去重复计算的i=j的部分:

[ans2=(size[x]-1)^2-sum_{i=1}^{k}size[i]^2 ]

再把两个答案加起来:

[ans=ans1+ans2=size[x]*2-1+(size[x]-1)^2-sum_{i=1}^{k}size[i]^1 ]

[ans=size[x]^2-sum_{i=1}^{k}size[i]^2 ]

然后我们来分析复杂度。最坏的情况就是:根直接连接其余所有点,并且每次询问都是根节点。此时时间复杂度就是O(N*M)。考虑优化。

显然重复计算过的我们不需要再算。记录ans数组,预处理出每个点的答案,时间复杂度就变成了O(N+M),期望得分100。

#include<iostream>
#include<cstring>
#include<cstdio>
#define maxn 10001
#define p 1000000007
using namespace std;

struct edge{
    int to,next;
    edge(){}
    edge(const int &_to,const int &_next){ to=_to,next=_next; }
}e[maxn<<1];
int head[maxn],k;

int size[maxn],ans[maxn];
int n,m,r;

inline int read(){
    register int x(0),f(1); register char c(getchar());
    while(c<'0'||'9'<c){ if(c=='-') f=-1; c=getchar(); }
    while('0'<=c&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar();
    return x*f;
}
inline void add(const int &u,const int &v){
    e[k]=edge(v,head[u]);
    head[u]=k++;
}

inline void dfs(int u,int pre){
    size[u]=1;
    for(register int i=head[u];~i;i=e[i].next){
        int v=e[i].to;
        if(v==pre) continue;
        dfs(v,u),size[u]+=size[v];
        ans[u]=(ans[u]+size[v]*size[v])%p;
    }
    ans[u]=(size[u]*size[u]%p-ans[u]+p)%p;
}

int main(){
    memset(head,-1,sizeof head);
    n=read(),r=read(),m=read();
    for(register int i=1;i<n;i++){
        int u=read(),v=read();
        add(u,v),add(v,u);
    }
    dfs(r,0);

    while(m--) printf("%d
",ans[read()]);
    return 0;
}

*相减的部分取余需要判负数......或者直接加个p上去

原文地址:https://www.cnblogs.com/akura/p/10837547.html