CF600E Lomsat gelral [dsu on tree]

Lomsat gelral

题目描述见链接 .


$\color{red}{正解部分}$

dsu on tree, 每次将 重儿子 的信息继承给父亲, 轻儿子 暴力扫. 由于任意点到根的路径上至多经过 $O(\log N)$ 条轻边, 每个点至多被暴力扫 $O(\log N)$ 次, 故时间复杂度 $O(N\log N)$ .


$\color{red}{实现部分}$

  • dsu on tree 先扫 轻儿子, 把 轻儿子 的答案全部计算出来, 并在返回前清空它们对计数数组的贡献 .
  • 再 扫 重儿子, 保留 重儿子 传上来的信息后, 再暴力扫 当前点 与 轻儿子 的子树, 计算答案 .
  • 根据参数决定是否将信息继承给父亲 (作为父亲的重儿子时保留, 否则清空) .
#include<bits/stdc++.h>
// `register` was only an optimisation hint; it was deprecated in C++11 and
// removed in C++17, where using it is ill-formed. `reg` therefore expands to
// nothing, which keeps every `reg int i` declaration valid under any standard.
#define reg
typedef long long ll;   // 64-bit accumulator type for colour sums

// Reads one decimal integer (optionally preceded by '-') from stdin, skipping
// any leading non-digit characters. Returns 0 if EOF is reached before a digit.
int read(){
        // Must be int, not char: getchar() returns either an unsigned-char value
        // or EOF, and only those values may be passed to isdigit().
        int c = getchar();
        int s = 0, flag = 1;
        while(c != EOF && !isdigit(c)){
                // A '-' directly applies to the number that follows it.
                if(c == '-'){ flag = -1; c = getchar(); break ; }
                c = getchar();
        }
        while(c != EOF && isdigit(c)) s = s*10 + c-'0', c = getchar();
        return s * flag;
}

// Capacity of every vertex-indexed array (vertices are numbered from 1).
const int maxn = 100000 + 10;

int N;                  // number of vertices, read in main()
int num0;               // number of directed edges stored so far; edge[] starts at index 1

// Tree shape, filled in by DFS_1.
int Fa[maxn];           // parent of each vertex (0 for the root)
int son[maxn];          // heavy child of each vertex (0 if it has no children)
int size[maxn];         // subtree size of each vertex

// Input and adjacency lists.
int col[maxn];          // colour of each vertex
int head[maxn];         // index of the first edge leaving each vertex (0 = none)

// Colour multiset currently being counted by scanson().
int cnt[maxn];          // occurrences per colour value (assumes colours are < maxn — TODO confirm)
int max_cnt;            // running maximum of cnt[] while vertices are being added
ll curc;                // sum of the colours whose count equals max_cnt

ll Ans[maxn];           // per-vertex answer printed by main()

struct Edge{
        int nxt;        // next edge leaving the same vertex (0 ends the list)
        int to;         // vertex this edge points to
} edge[maxn * 2];       // every undirected edge is stored in both directions

// Appends the directed edge from -> to to the adjacency list of `from`.
// Lists are singly linked through edge[].nxt with head insertion; edge indices
// start at 1 so that 0 can mark the end of a list.
// (The original `(Edge){...}` was a C99 compound literal, which ISO C++ does
// not allow; plain member assignment is portable and does the same thing.)
void Add(int from, int to){
        ++num0;
        edge[num0].nxt = head[from];
        edge[num0].to = to;
        head[from] = num0;
}

void DFS_1(int k, int fa){
        size[k] = 1, Fa[k] = fa;
        for(reg int i = head[k]; i; i = edge[i].nxt){
                int to = edge[i].to;
                if(to == fa) continue ;
                DFS_1(to, k); size[k] += size[to];
                if(size[to] > size[son[k]]) son[k] = to;
        }
}

void scanson(int k, const int &opt, const int &sn){
        cnt[col[k]] += opt;
        if(cnt[col[k]] > max_cnt) max_cnt = cnt[col[k]], curc = col[k];
        else if(cnt[col[k]] == max_cnt) curc += col[k];
        for(reg int i = head[k]; i; i = edge[i].nxt){
                int to = edge[i].to;
                if(to == Fa[k] || to == sn) continue ;
                scanson(to, opt, sn);
        }
}

// Second pass (dsu on tree): computes Ans[v] for every vertex v in k's subtree.
//   k   : current vertex
//   fa  : parent of k (0 for the root)
//   opt : 1 -> on return, leave k's whole subtree counted in cnt[] and
//              max_cnt/curc (used for a heavy child, whose counts the parent reuses);
//         0 -> remove k's whole subtree from cnt[] before returning.
// By the way the calls below are ordered, cnt[] is all zero on entry.
void DFS_2(int k, int fa, int opt){
        // Light children first; each one clears its own contribution (opt = 0).
        for(reg int i = head[k]; i; i = edge[i].nxt){
                int to = edge[i].to;
                if(to == fa || to == son[k]) continue ;
                DFS_2(to, k, 0);
        }
        // The removals above leave max_cnt/curc meaningless, so reset them.
        // This must stay between the light-child loop and the heavy-child call.
        max_cnt = curc = 0;
        // Heavy child last, keeping its counts (opt = 1).
        if(son[k]) DFS_2(son[k], k, 1);
        // Add k itself and all light subtrees on top of the heavy child's counts.
        scanson(k, 1, son[k]); 
        Ans[k] = curc;
        // k is a light child of its parent: undo everything, heavy part included.
        if(!opt) scanson(k, -1, 0);
}

// Reads N, the colour of each vertex and N-1 undirected edges. Roots the tree
// at vertex 1, runs both passes and prints Ans[1..N] separated by spaces.
int main(){
        N = read();
        for(int v = 1; v <= N; ++v) col[v] = read();
        for(int e = 1; e < N; ++e){
                const int x = read();
                const int y = read();
                // Store each undirected edge in both directions.
                Add(x, y);
                Add(y, x);
        }
        DFS_1(1, 0);
        DFS_2(1, 0, 1);   // opt = 1: the root's counts are never removed
        for(int v = 1; v <= N; ++v) printf("%lld ", Ans[v]);
        return 0;
}
原文地址:https://www.cnblogs.com/zbr162/p/11822393.html