hihocoder #1112 树上的好路径


时间限制:1000ms
单点时限:1000ms
内存限制:256MB

描述

现在有一棵有N个带权顶点的树,顶点编号为1,2,...,N。我们定义一条路径的次小(最小)权为它经过的所有顶点(包括起点和终点)中权值次小(最小)顶点的权值。现在给定常数c,你需要求出:存在多少个使得u<v的顶点组(u,v),满足从u到v的最短路的次小权恰为c但最小权不为c。
输入

第一行有两个数N和c。(1<=n<=100000)

第二行N个数,依次表示每个顶点的权值。

接下来N-1行,每行两个数,代表这棵树的一条边所连接的两个顶点的编号。

我们保证输入中的数都在int以内。
输出

一个数,为答案。
样例输入

    8 2
    2 2 3 3 1 2 3 2
    1 2
    3 2
    3 8
    4 2
    5 2
    5 6
    6 7

样例输出

    17


Solution

为了方便, 把我们要考虑的树记作$T=(V, E)$, 用$w[u]$表示节点$u$ ($uin V$) 的权值.

先考虑一个简化的问题:

求最小权小于$c$且次小权不小于$c$的路径$(u, v)$的数目.

为了解决这个问题, 我们考虑如下的添边过程:

我们考虑一个动态的图$S(V, E'), E'subseteq E$.

从$S=(V, emptyset)$开始, 先把所有满足$w[u]ge c land w[v] ge c$的边$(u, v)$加到$S$中,

然后考虑满足

[w[u]<c land w[v]ge c lor w[u]ge c land w[v] <c]

的边$(u, v)$, 不失一般性, 不妨设 $w[u]<c, w[v]ge c$.

我们先把$u$固定为$u_0$, 考虑将所有符合上述条件的边${(u_0, v)}$加到$s$中将能获得多少满足条件的路径.

显然这些满足条件的路径上的最小权就是$w[u_0]$.

(未完待续...)

(无力写了, 先把代码贴上)


UPD

前面写得太罗嗦了, 结果现在自己都看不大懂了. 其实做法一句话就能说清楚:

最小权小于$c$, 次小权不小于$c$的路径数 $-$ 最小权小于$c$, 次小权大于$c$的路径数

Implementation

  1 #include <bits/stdc++.h>
  2 using namespace std;
  3 using LL=long long;
  4 const int N{1<<17};
  5 
  6 int a[N];
  7 
  8 struct edge{
  9     int u, v;
 10     void read(){
 11         cin>>u>>v;
 12     }
 13 }e[N];
 14 
 15 struct DSU{
 16     int par[N], size[N];
 17     int n;
 18     DSU(int n):n(n){}
 19     void init(){
 20         for(int i=1; i<=n; i++){
 21             par[i]=i;
 22             size[i]=1;
 23         }
 24     }
 25     int find(int x){
 26         return x==par[x]?x:par[x]=find(par[x]);
 27     }
 28     void unite(int x, int y){
 29         x=find(x), y=find(y);
 30         if(x!=y) par[x]=y, size[y]+=size[x];
 31     }
 32 };
 33 
 34 vector<int> f[N];
 35 
 36 void prep(DSU &b, int n, int c){
 37     b.init();
 38     for(int i=1; i<=n; i++) f[i].clear();
 39     for(int i=1; i<n; i++){
 40         int u=e[i].u, v=e[i].v;
 41         if(a[u]>=c && a[v]>=c){
 42             b.unite(u, v);
 43         }
 44     }
 45 }
 46 
 47 int main(){
 48     int n, c;
 49     cin>>n>>c;
 50     DSU b(n);
 51 
 52     for(int i=1; i<=n; i++)
 53         cin>>a[i];
 54     for(int i=1; i<n; i++)
 55         e[i].read();
 56 
 57 
 58     LL res=0;
 59 
 60     prep(b, n, c);
 61 
 62     for(int i=1; i<n; i++){
 63         int u=e[i].u, v=e[i].v;
 64         if(a[u]<c ^ a[v]<c){    //tricky
 65             // cout<<u<<' '<<v<<endl;
 66             if(a[v]<c) swap(u, v);
 67             int rv=b.find(v);
 68             // res+=LL(b.size[u])*LL(b.size[v]);
 69             // if(ru!=rv)
 70             f[u].push_back(b.size[rv]);
 71         }
 72     }
 73 
 74     for(int i=1; i<=n; i++){
 75         // if(f[i].size()) cout<<"#"<<i<<endl;
 76         LL sum=0, t=0;
 77         for(auto &x: f[i])
 78             sum+=x;
 79         for(auto &x: f[i]) t+=LL(x)*(sum-x);
 80         res+=t>>1;
 81         res+=sum;
 82     }
 83 
 84 
 85     prep(b, n, c+1);
 86 
 87     for(int i=1; i<n; i++){
 88         int u=e[i].u, v=e[i].v;
 89         if(a[u]<c && a[v]>c || a[u]>c && a[v]<c){    //tricky
 90             if(a[v]<c) swap(u, v);
 91             int rv=b.find(v);
 92             // res+=LL(b.size[u])*LL(b.size[v]);
 93             f[u].push_back(b.size[rv]);
 94         }
 95     }
 96 
 97     for(int i=1; i<=n; i++){
 98         LL sum=0, t=0;
 99         for(auto &x: f[i])
100             sum+=x;
101         for(auto &x: f[i]) t+=LL(x)*(sum-x);
102         res-=t>>1, res-=sum;
103     }
104 
105     cout<<res<<endl;
106 }
原文地址:https://www.cnblogs.com/Patt/p/5833483.html