NOIP2017-清北学堂国庆模拟1-Tyvj4865 天天和树

原题
树的直径。严谨的证明还没搞懂。
或许三次BFS可以写个template缩短代码?

#include <cstdio>
#include <cstring>
using namespace std;
// Capacity (in vertices) of every statically sized array in this file.
const int MaxN = 100000;

// One cell of a singly linked adjacency list.
struct node {
  int vID;   // index of the vertex this cell refers to (-1 when default-built)
  node *nx;  // next cell in the same list, NULL at the end of the list

  // Builds a detached placeholder cell.
  node() {
    vID = -1;
    nx = NULL;
  }

  // Builds a cell for vertex init_vID that links to init_nx.
  node(int init_vID, node* init_nx) {
    vID = init_vID;
    nx = init_nx;
  }
};
node *arc[MaxN];

node *newedge; 
void add_edge(int x, int y) {
  newedge = new node(y, arc[x]);
  arc[x] = newedge;
  newedge = new node(x, arc[y]);
  arc[y] = newedge;
}

// mark[v] is set once vertex v has been enqueued in the current pass.
bool mark[MaxN];

// q1: BFS queue of plain vertex indices (first pass).
int q1[MaxN];

// from[v]: the vertex from which v was first reached in the second pass.
int from[MaxN];

// BFS queue entry that carries a vertex together with its distance.
struct q2_ele {
  int vID;
  int dist;
};

// q2: BFS queue used by the second and third passes.
q2_ele q2[MaxN];

int main() {
  freopen("tree.in", "r", stdin);
  freopen("tree.out", "w", stdout);
  int n, i, x, y;
  scanf("%d", &n);
  for (i = 1; i < n; i++) {
    scanf("%d%d", &x, &y);
    x--; y--;
    add_edge(x, y);
  }
  
  //第一遍:找直径的一个端点
  int head, tail, u, s, curr;
  q1[0] = 0;
  mark[0] = true;
  for (head = 0, tail = 1; head != tail; head++) {
    curr = q1[head];
    for (node *adj = arc[curr]; adj != NULL; adj = adj->nx)
      if (!mark[adj->vID]) {
        mark[adj->vID] = true;
        q1[tail] = adj->vID;
        tail++;
      }
    if (head + 1 == tail)
      s = curr;
  }
  
  //第二遍:找出直径上的所有点
  int t, max_dist, curr_vID, curr_dist;
  memset(mark, false, sizeof(mark));
  q2[0] = (q2_ele) {s, 0};
  t = s; max_dist = 0;
  mark[s] = true;
  for (head = 0, tail = 1; head != tail; head++) {
    curr_vID = q2[head].vID; curr_dist = q2[head].dist;
    for (node *adj = arc[curr_vID]; adj != NULL; adj = adj->nx)
      if (!mark[adj->vID]) {
        mark[adj->vID] = true;
        q2[tail] = (q2_ele) {adj->vID, curr_dist + 1};
        from[adj->vID] = curr_vID;
        tail++;
      }
    if (curr_dist > max_dist) {
      t = curr_vID;
      max_dist = curr_dist;
    }
  }
  
  //第三遍:找到距离路径最远的结点
  int ans;
  memset(mark, false, sizeof(mark));
  for (i = t, tail = 0; i != s; i = from[i], tail++) {
    q2[tail] = (q2_ele) {i, 0};
    mark[i] = true;
  }
  mark[s] = true;
  for (head = 0; head != tail; head++) {
    curr_vID = q2[head].vID; curr_dist = q2[head].dist;
    for (node *adj = arc[curr_vID]; adj != NULL; adj = adj->nx)
      if (!mark[adj->vID]) {
        mark[adj->vID] = true;
        q2[tail] = (q2_ele) {adj->vID, curr_dist + 1};
        tail++;
      }
    if (head + 1 == tail)
      ans = curr_dist;
  }
  
  printf("%d
", ans);
  return 0;
}
原文地址:https://www.cnblogs.com/P6174/p/7545634.html