【HDOJ6686】Rikka with Travels(树形DP)

题意:给定一棵n个点、边权均为1的树,求有多少个有序数对(l1,l2),使得树上存在两条没有公共点的简单路径,其长度(路径上的点数)分别为l1和l2。

n<=1e5

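思路:两条不相交的路径之间必有一条树边把它们分开。枚举每条边(u,v),设v子树一侧的直径为d1、另一侧的直径为d2(按点数计),则所有l1≤d1且l2≤d2的数对都可行(交换两侧同理)。第一遍DFS求每个子树的最长下链与直径f;第二遍DFS借助兄弟子树的前缀/后缀合并,求子树外部的同类信息g。记ans[l]为第一条路径长为l时第二条路径的最大长度,取后缀最大值后答案为Σans[l]。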

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 typedef long long ll;
  4 typedef unsigned int uint;
  5 typedef unsigned long long ull;
  6 typedef pair<int,int> PII;
  7 typedef pair<ll,ll> Pll;
  8 typedef vector<int> VI;
  9 typedef vector<PII> VII;
 10 #define N  310000
 11 #define M  4100000
 12 #define fi first
 13 #define se second
 14 #define MP make_pair
 15 #define pi acos(-1)
 16 #define mem(a,b) memset(a,b,sizeof(a))
 17 #define rep(i,a,b) for(int i=(int)a;i<=(int)b;i++)
 18 #define per(i,a,b) for(int i=(int)a;i>=(int)b;i--)
 19 #define lowbit(x) x&(-x)
 20 #define Rand (rand()*(1<<16)+rand())
 21 #define id(x) ((x)<=B?(x):m-n/(x)+1)
 22 #define ls p<<1
 23 #define rs p<<1|1
 24 
 25 const ll MOD=1e9+7,inv2=(MOD+1)/2;
 26       double eps=1e-6;
 27       int INF=1e9;
 28       int da[4]={-1,1,0,0};
 29       int db[4]={0,0,-1,1};
 30 
 31 
 32 int read()
 33 {
 34    int v=0,f=1;
 35    char c=getchar();
 36    while(c<48||57<c) {if(c=='-') f=-1; c=getchar();}
 37    while(48<=c&&c<=57) v=(v<<3)+v+v+c-48,c=getchar();
 38    return v*f;
 39 }
 40 
 41 struct data
 42 {
 43     int a,b;
 44 }f[N],g[N],t1[N],t2[N];
 45 
 46 int tot,ans[N],head[N],vet[N],nxt[N];
 47 
 48 data operator + (const data &a,const data &b)
 49 {
 50     return (data){max(a.a,b.a),max(a.a+b.a,max(a.b,b.b))};
 51 }
 52 
 53 data operator + (const data &a,const int &b)
 54 {
 55     return (data){a.a+b,max(a.a+b,a.b)};
 56 }
 57 
 58 void dfs1(int u,int fa)
 59 {
 60     int e=head[u];
 61     while(e)
 62     {
 63         int v=vet[e];
 64         if(v!=fa)
 65         {
 66             dfs1(v,u);
 67             f[u]=f[u]+(f[v]+1);
 68         }
 69         e=nxt[e];
 70     }
 71 
 72 }
 73 
 74 void dfs2(int u,int fa)
 75 {
 76     int s=0;
 77     int e=head[u];
 78     while(e)
 79     {
 80         int v=vet[e];
 81         if(v!=fa)
 82         {
 83             s++;
 84             t1[s]=f[v]+1;
 85             t2[s]=f[v]+1;
 86         }
 87         e=nxt[e];
 88     }
 89     rep(i,2,s) t1[i]=t1[i-1]+t1[i];
 90     per(i,s-1,1) t2[i]=t2[i+1]+t2[i];
 91     int i=0;
 92     e=head[u];
 93     while(e)
 94     {
 95         int v=vet[e];
 96         if(v!=fa)
 97         {
 98             i++;
 99             g[v]=g[u];
100             if(i>=2) g[v]=g[v]+t1[i-1];
101             if(i<=s-1) g[v]=g[v]+t2[i+1];
102             ans[g[v].b+1]=max(ans[g[v].b+1],f[v].b+1);
103             ans[f[v].b+1]=max(ans[f[v].b+1],g[v].b+1);
104             g[v]=g[v]+1;
105         }
106         e=nxt[e];
107     }
108     e=head[u];
109     while(e)
110     {
111         int v=vet[e];
112         if(v!=fa) dfs2(v,u);
113         e=nxt[e];
114     }
115 }
116 
117 void add(int a,int b)
118 {
119     nxt[++tot]=head[a];
120     vet[tot]=b;
121     head[a]=tot;
122 }
123 
124 int main()
125 {
126     //freopen("1.in","r",stdin);
127     //freopen("1.out","w",stdout);
128 
129     int cas;
130     scanf("%d",&cas);
131 
132     while(cas--)
133     {
134         int n=read();
135         tot=0;
136         rep(i,1,n) head[i]=0;
137         rep(i,1,n)
138         {
139             f[i].a=f[i].b=0;
140             g[i].a=g[i].b=0;
141             ans[i]=0;
142         }
143         rep(i,1,n-1)
144         {
145             int x=read(),y=read();
146             add(x,y);
147             add(y,x);
148         }
149         dfs1(1,0);
150         dfs2(1,0);
151         per(i,n-1,1) ans[i]=max(ans[i],ans[i+1]);
152         ll s=0;
153         rep(i,1,n) s+=ans[i];
154         printf("%I64d
",s);
155 
156 
157 
158     }
159 
160     return 0;
161 }
原文地址:https://www.cnblogs.com/myx12345/p/11666921.html