Codeforces GYM 101968 A. Tree Game

差点自闭,感谢大佬帮忙找bug

题目:https://codeforces.com/gym/101968/problem/A

找树的重心+思维

找到树的重心,如果重心只有一个,以重心为根节点dfs,求各节点深度,那么任意一对节点都符合题意,只要让先手mark深度小的节点即可,其中同样深度的节点,交换位置扔符合题意,要加上dep[i]*(dep[i]-1)/2

如果重心有两个,分别以两个重心为根节点dfs,以任意一对节点都符合题意为基础(ans=n*(n-1)/2),如果两个深度相同的节点分别挂在两个不同的重心上,那么他们无法作为符合题意的一对节点,要减去这一部分,对每个重心深度相同的节点,分别加上交换节点扔符合题意的一部分,嗯就这样

#include<iostream>
#include<cstdio> 
#include<cmath>
#include<queue>
#include<vector>
#include<string.h>
#include<cstring>
#include<algorithm>
#include<set>
#include<stack>
#include<map>
#include<fstream>
#include<cstdlib>
#include<ctime>
#include<list>
#include<climits>
#include<bitset>
#include<random>
using namespace std;
#define fopen freopen("input.in", "r", stdin);freopen("output.in", "w", stdout);
#define left asfdasdasdfasdfsdfasfsdfasfdas1
#define right asfdasdasdfasdfsdfasfsdfasfdas2
#define y1 asfdasdasdfasdfsdfasfsdfasfdas3
typedef long long ll;
typedef unsigned ui;
typedef long double ld;
const int dell[8][2]={{1,2},{1,-2},{2,1},{2,-1},{-1,2},{-1,-2},{-2,1},{-2,-1}};
ll mod=1e9+7;
const ll inf=(1LL<<31)-1;
const int maxn=1e6+7;
const int maxm=1e6+7;
const double eps=1e-8;
const double pi=acos(-1.0);
const int csize=22;
int n,k,m,ar[maxn];
struct node{
    int b,nex;
}no[maxn*2];
int head[maxn],sz,mx,root;
int pre[maxn];
ll sspre[maxn];
int pre2[maxn];
ll sspre2[maxn];
void init(){
    for(int i=0;i<=n;i++)head[i]=-1;
    sz=0;
}
void add(int a,int b){
    no[sz].b=b;
    no[sz].nex=head[a];
    head[a]=sz++;
}
int dep[maxn],num[maxn];
void dfs(int u,int fa)
{
    num[u]=1;
    if(dep[u]>mx){
        mx=dep[u];
        root=u;
    }
    for(int i=head[u];i!=-1;i=no[i].nex){
        int v=no[i].b;
        if(v==fa)continue;
        dep[v]=dep[u]+1;
        dfs(v,u);
        num[u] += num[v];
    }
}
bool findmid(int u,int& mid){
    bool can=1;
    for(int i=head[u];i!=-1;i=no[i].nex){
        int v=no[i].b;
        if(dep[v]==dep[u]+1){
            if(findmid(v,mid)){
                return 1;
            }
            can &= num[v]<(n+1)/2;
        }
    }
    if(can && num[u]>=(n+1)/2){
        mid=u;
        return 1;
    }
    else return 0;
}

int main()
{
    //fopen
    //freopen("input.in","r",stdin);
    int t;scanf("%d",&t);
    while(t--){
        scanf("%d",&n);
        init(); 
        for(int i=2;i<=n;i++){
            int x;scanf("%d",&x);
            add(i,x);
            add(x,i);
        }
        mx=0;root=0;
        dep[1]=1;
        dfs(1,-1);
        int mid=0;
        findmid(1,mid);
        ll ans=1LL*n*(n-1)/2;
        if(n%2==0 && num[mid]==n/2){
            int mid2=0;
            for(int i=head[mid];i!=-1;i=no[i].nex){
                if(dep[no[i].b]+1==dep[mid]){
                    mid2=no[i].b;
                    break;
                }
            }
            for(int i=1;i<=n;i++)dep[i]=0;
            dep[mid2]=1;
            dfs(mid2,mid);
            for(int i=0;i<=n+1;i++)pre[i]=pre2[i]=0;
            for(int i=1;i<=n;i++){
                if(dep[i]>0)pre2[dep[i]]++;
            }

            for(int i=1;i<=n;i++)dep[i]=0;
            dep[mid]=1;
            dfs(mid,mid2);
            for(int i=1;i<=n;i++){
                if(dep[i]>0){
                    pre[dep[i]]++;
                }
            }
            for(int i=1;i<=n;i++){
                ans -= (ll)pre[i]*pre2[i];
            }
            for(int i=1;i<=n;i++){
                ans += (ll)pre[i]*(pre[i]-1)/2;
                ans += (ll)pre2[i]*(pre2[i]-1)/2;
            }
        }
        else{
            //if(fir==1609)while(1);
            dep[mid]=1;
            dfs(mid,-1);
            for(int i=0;i<=n+1;i++)pre[i]=0;
            for(int i=1;i<=n;i++)pre[dep[i]]++;
            for(int i=1;i<=n;i++){
                ans += (ll)pre[i]*(pre[i]-1)/2;
            }
        }
        printf("%lld
",ans);
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/wa007/p/9964458.html