[POJ 3417] Network

[题目链接]

          http://poj.org/problem?id=3417

[算法]

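         树上差分 (difference on a tree): for every extra edge (u, v), add 1 at u and at v and subtract 2 at lca(u, v); after accumulating subtree sums, sum[x] is the number of extra edges whose tree path covers the edge between x and its parent. A tree edge covered 0 times pairs with any of the m extra edges, one covered exactly once pairs only with the extra edge covering it, and one covered twice or more pairs with none.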

[代码]

       

#include <algorithm>  
#include <bitset>  
#include <cctype>  
#include <cerrno>  
#include <clocale>  
#include <cmath>  
#include <complex>  
#include <cstdio>  
#include <cstdlib>  
#include <cstring>  
#include <ctime>  
#include <deque>  
#include <exception>  
#include <fstream>  
#include <functional>  
#include <limits>  
#include <list>  
#include <map>  
#include <iomanip>  
#include <ios>  
#include <iosfwd>  
#include <iostream>  
#include <istream>  
#include <ostream>  
#include <queue>  
#include <set>  
#include <sstream>  
#include <stdexcept>  
#include <streambuf>  
#include <string>  
#include <utility>  
#include <vector>  
#include <cwchar>  
#include <cwctype>  
#include <stack>  
#include <limits.h>
using namespace std;
const int MAXN = 100010;        // capacity of every per-vertex array (vertices are 1-based)
const int MAXLOG = 20;          // binary-lifting levels: anc[x][k] for k = 0 .. MAXLOG-1

/* Adjacency lists stored as singly linked edge chains: head[x] is the index of
 * the first edge leaving x and e[j].nxt the next one; index 0 ends a chain.
 * Every undirected tree edge is inserted in both directions, hence MAXN << 1. */
struct edge
{
        int to;                 // target vertex
        int nxt;                // next edge in the same source vertex's chain
} e[MAXN << 1];

/* State shared by main() and the tree routines. */
int i, n, m, tot, u, v, ans;
int sum[MAXN];                  // difference marks; dfs2 turns them into subtree sums
int dep[MAXN];                  // depth of each vertex (root is at depth 0)
int head[MAXN];                 // first edge of each vertex's chain, 0 if none
int anc[MAXN][MAXLOG];          // anc[x][k] = 2^k-th ancestor of x, 0 above the root

/* Prepend the directed edge u -> v to u's adjacency chain.
 * Slots are handed out starting at 1, so index 0 can mark the end of a chain. */
inline void addedge(int u,int v)
{
        const int slot = ++tot;
        e[slot].to = v;
        e[slot].nxt = head[u];  // old first edge becomes the successor
        head[u] = slot;
}
/*
 * Fill dep[] and the binary-lifting table anc[][] for the subtree rooted at u.
 *
 * Expects dep[u] and anc[u][0] to be set already; main() zeroes them, so for
 * the root u = 1 both are 0. For every vertex x the routine sets
 * anc[x][k] = anc[anc[x][k-1]][k-1] while 2^k <= dep[x]; higher levels are
 * left untouched.
 *
 * Iterative on purpose: the recursive form used one call frame per tree
 * level, so a path-shaped tree with ~1e5 vertices could exhaust the stack.
 * A vertex is pushed only after its parent has been popped and finished, so
 * every ancestor's table row is complete before the row that reads it.
 */
inline void dfs1(int u)
{
        std::vector<int> stk;
        stk.push_back(u);
        while (!stk.empty())
        {
                int x = stk.back();
                stk.pop_back();
                for (int k = 1; k < MAXLOG; k++)
                {
                        if (dep[x] < (1 << k)) break;   // no 2^k-th ancestor exists
                        anc[x][k] = anc[anc[x][k - 1]][k - 1];
                }
                for (int j = head[x]; j; j = e[j].nxt)
                {
                        int y = e[j].to;
                        if (y != anc[x][0])             // skip the edge back to the parent
                        {
                                dep[y] = dep[x] + 1;
                                anc[y][0] = x;
                                stk.push_back(y);
                        }
                }
        }
}
/*
 * Turn the difference marks in sum[] into subtree sums over the subtree of u:
 * afterwards sum[x] = (original sum[x]) + sum of sum[c] over children c of x.
 * Parent links come from anc[x][0], which dfs1 must already have filled.
 *
 * Iterative on purpose, to avoid one call frame per tree level on deep trees.
 * The first pass records a preorder; every descendant of x is recorded after
 * x. Sweeping that order backwards therefore finishes each subtree before
 * folding it into its parent.
 */
inline void dfs2(int u)
{
        std::vector<int> order;
        std::vector<int> stk(1, u);
        while (!stk.empty())
        {
                int x = stk.back();
                stk.pop_back();
                order.push_back(x);
                for (int j = head[x]; j; j = e[j].nxt)
                {
                        int y = e[j].to;
                        if (y != anc[x][0]) stk.push_back(y);   // children only
                }
        }
        // order[0] is u itself; its own total is complete once the loop ends.
        for (int k = static_cast<int>(order.size()) - 1; k >= 1; k--)
        {
                int x = order[k];
                sum[anc[x][0]] += sum[x];
        }
}
/*
 * Lowest common ancestor of u and v by binary lifting.
 * Reads dep[] and anc[][], which dfs1 must have filled for the current tree.
 */
inline int lca(int u,int v)
{
        if (dep[u] > dep[v]) std::swap(u,v);    // afterwards v is the deeper vertex

        // Lift v by the depth gap, taking one 2^k jump per set bit of the gap.
        int gap = dep[v] - dep[u];
        for (int k = 0; k < MAXLOG && gap != 0; k++, gap >>= 1)
        {
                if (gap & 1) v = anc[v][k];
        }
        if (u == v) return u;

        // From the highest level down, jump both while they land on different vertices;
        // they end up as the two children of the LCA.
        int level = MAXLOG;
        while (level-- > 0)
        {
                int up_u = anc[u][level];
                int up_v = anc[v][level];
                if (up_u != up_v)
                {
                        u = up_u;
                        v = up_v;
                }
        }
        return anc[u][0];
}

/*
 * Processes test cases until input runs out. Each case is:
 *   n m, then n-1 tree edges, then m extra edges.
 *
 * Each extra edge (u, v) marks +1 at u and v and -2 at lca(u, v). After dfs2,
 * sum[x] is the number of extra edges whose tree path covers the edge
 * x - parent(x). The printed count adds:
 *   m  for every tree edge covered 0 times,
 *   1  for every tree edge covered exactly once,
 *   0  otherwise.
 */
int main() 
{
        // "== 2" rather than "!= EOF": a malformed token makes scanf return 0
        // without consuming input, which previously looped forever.
        while (scanf("%d%d",&n,&m) == 2)
        {
                tot = 0;
                memset(anc,0,sizeof(anc));
                memset(dep,0,sizeof(dep));
                for (i = 1; i <= n; i++) 
                {
                        head[i] = 0;
                        sum[i] = 0;
                }
                for (i = 1; i < n; i++)
                {
                        scanf("%d%d",&u,&v);
                        addedge(u,v);
                        addedge(v,u);
                }
                dfs1(1);
                for (i = 1; i <= m; i++)
                {
                        scanf("%d%d",&u,&v);
                        sum[u]++; sum[v]++;
                        sum[lca(u,v)] -= 2;
                }
                dfs2(1);
                // Up to (n - 1) * m can be added, which may exceed INT_MAX,
                // so accumulate in 64 bits.
                long long total = 0;
                for (i = 2; i <= n; i++)        // each non-root vertex owns the edge to its parent
                {
                        if (sum[i] == 0) total += m;
                        else if (sum[i] == 1) total++;
                }
                // iostream prints long long without a platform-specific printf length modifier.
                std::cout << total << '\n';
        }
        
        return 0;
}
原文地址:https://www.cnblogs.com/evenbao/p/9382565.html