清北学堂模拟赛d4t5 b

分析:一眼树形dp题,就是不会写QAQ.树形dp嘛,定义状态肯定有一维是以i为根的子树,其实这道题只需要这一维就可以了.设f[i]为以i为根的子树中的权值和.先处理子树内部的情况,用一个数组son[i]表示以i为根的子树中,i能走到的节点个数,可以利用son数组和当前点的权值来更新f数组.

     处理了每个子树内部的情况,接下来就要合并它们,将每一个根节点作为中间点,算一下中间点权值的贡献,利用乘法原理算出有多少对点对经过中间点,乘一下就ok了.

树形dp的基本状态定义要熟记,有些题目子树内部是互相独立的,可以在子树里面单独计算,最后再合并一下.

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

const int maxn = 300010;

int n, a[maxn], head[maxn], to[maxn * 2], nextt[maxn * 2], tot = 1, w[maxn * 2];
long long ans, f[maxn], son[maxn];

void add(int x, int y, int z)
{
    w[tot] = z;
    to[tot] = y;
    nextt[tot] = head[x];
    head[x] = tot++;
}

void dfs(int u, int fa, int col)
{
    long long res = 0;
    f[u] = a[u];
    son[u] = 1;
    bool flag = 1;
    for (int i = head[u]; i; i = nextt[i])
    {
        int v = to[i];
        if (v == fa)
            continue;
            dfs(v, u, w[i]);
            if (col != w[i])
            {
                flag = 0;
                son[u] += son[v];
                f[u] += son[v] * a[u] + f[v];
            }
            res += son[v] * a[u] + f[v];
    }
    ans += res;
    if (flag)
        return;
    for (int i = head[u]; i; i = nextt[i])
    {
        int v1 = to[i];
        if (v1 != fa)
            for (int j = i; j; j = nextt[j]) //防止重复统计,所以j=i而不是j=head[u]
            {
                int v2 = to[j];
                if (v2 != fa && w[i] != w[j])
                    ans += son[v1] * f[v2] + son[v2] * f[v1] + a[u] * son[v1] * son[v2];
            }
    }
}

int main()
{
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        scanf("%d", &a[i]);
    for (int i = 1; i < n; i++)
    {
        int x, y, z;
        scanf("%d%d%d", &x, &y, &z);
        add(x, y, z);
        add(y, x, z);
    }
    dfs(1, 0, 0);
    printf("%lld
", ans);

    return 0;
}
原文地址:https://www.cnblogs.com/zbtrs/p/7646580.html