BZOJ 3611: [Heoi2014]大工程 [虚树 DP]

传送门

题意:

多次询问,求最长链最短链链总长


煞笔$DP$记录$d,c,f,g$

$MD$该死拍了一下午没问题然后交上去就$T$

然后发现树链剖分写成$size[v]+=size[u]$

我想知道我随机生成的大数据是怎么跑过去的!!!!!!!!

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
typedef long long ll;
const int N=1e6+5,INF=1e9;
inline int read(){
    char c=getchar();int x=0,f=1;
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
    return x*f;
}

int n,Q;
struct Edge{
    int v,ne,w;
}e[N<<1];
int cnt,h[N];
inline void ins(int u,int v){
    cnt++;
    e[cnt].v=v;e[cnt].ne=h[u];h[u]=cnt;
    cnt++;
    e[cnt].v=u;e[cnt].ne=h[v];h[v]=cnt;
}
int deep[N];
inline void ins2(int u,int v){
    cnt++;
    e[cnt].v=v;e[cnt].ne=h[u];h[u]=cnt;
    e[cnt].w=deep[v]-deep[u];
}

int dfn[N],dfc,top[N],size[N],mx[N],fa[N];
void dfs(int u){
    size[u]++;
    for(int i=h[u];i;i=e[i].ne){
        int v=e[i].v;
        if(v==fa[u]) continue;
        fa[v]=u;deep[v]=deep[u]+1;
        dfs(v);
        size[u]+=size[v];
        if(size[v]>size[mx[u]]) mx[u]=v;
    }
}
void dfs2(int u,int anc){
    dfn[u]=++dfc; top[u]=anc;
    if(mx[u]) dfs2(mx[u],anc);
    for(int i=h[u];i;i=e[i].ne)
        if(e[i].v!=fa[u] && e[i].v!=mx[u]) dfs2(e[i].v,e[i].v);
}
inline int lca(int x,int y){
    while(top[x]!=top[y]){
        if(deep[top[x]]<deep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    return deep[x]<deep[y] ? x : y;
}

int key[N],c[N],f[N],g[N];
ll d[N],Sum;
int Max,Min;
void dp(int u){//printf("dp %d
",u);
    d[u]=0;
    if(key[u]) c[u]=1,f[u]=g[u]=0;
    else c[u]=0,f[u]=-INF,g[u]=INF;

    for(int i=h[u];i;i=e[i].ne){
        int v=e[i].v,w=e[i].w;
        dp(v);
        Sum+=d[u]*c[v]+c[u]*(d[v]+(ll)c[v]*w);
        Max=max(Max,f[u]+f[v]+w);
        Min=min(Min,g[u]+g[v]+w);

        d[u]+=d[v]+c[v]*w;
        c[u]+=c[v];
        f[u]=max(f[u],f[v]+w);
        g[u]=min(g[u],g[v]+w);
    }
    h[u]=0;
}
int a[N];
inline bool cmp(int x,int y){return dfn[x]<dfn[y];}
int st[N];
void solve(){
    cnt=0;
    int n=read();//printf("n %d
",n);
    for(int i=1;i<=n;i++) a[i]=read(),key[a[i]]=1;
    sort(a+1,a+1+n,cmp);
    int top=0;
    for(int i=1;i<=n;i++){
        if(!top) {st[++top]=a[i];continue;}
        int x=a[i],f=lca(x,st[top]);
        while(dfn[f]<dfn[st[top]]){
            if(dfn[f]>=dfn[st[top-1]]){
                ins2(f,st[top--]);
                if(f!=st[top]) st[++top]=f;
                break;
            }else ins2(st[top-1],st[top]),top--;
        }
        st[++top]=x;
    }
    while(top>1) ins2(st[top-1],st[top]),top--;
    Sum=0;Max=-INF;Min=INF;
    dp(st[1]);
    for(int i=1;i<=n;i++) key[a[i]]=0;
    printf("%lld %d %d
",Sum,Min,Max);
}
int main(){
    freopen("in","r",stdin);
    n=read();
    for(int i=1;i<n;i++) ins(read(),read());
    dfs(1);dfs2(1,1);
    memset(h,0,sizeof(h));
    Q=read();
    while(Q--) solve();
}
原文地址:https://www.cnblogs.com/candy99/p/6526913.html