2019 沈阳网络赛 D Fish eating fruit ( 树形DP)

题目传送门

题意:求一颗树中所有点对(a,b)的路径长度,路径长度按照模3之后的值进行分类,最后分别求每一类的和

分析:树形DP

(dp[i][j]) 表示以 i 为根的子树中,所有子节点到 i 的路径长度模3等于 j 的路径之和
(c[i][j]) 表示以 i 为根的子树中,所有子节点到 i 的路径长度模3等于 j 的点数
(ok[i][j]) 表示以 i 为根的子树中,是否有子节点到 i 的路径长度模3等于 j

每次只考虑所有经过根 x 的路径,并且路径的一个端点在 x 的一颗子树上,另一个端点在 x 的另一颗子树上。(可以想到其他所有情况都可以在考虑 x 的子树结点或者是x的祖先结点时被考虑到)
假设当前枚举到 x 的子节点 y,之前遍历的子节点已经使得三个数组更新。那么我们假设要计算的路径的起点在 y ,要计算的路径的终点在之前遍历过的子节点中。

计算答案贡献:
关于x-y的连边的贡献为
(c[x][a] * c[y][b] * edge)
关于起点到 y 的所有路径长度的贡献为
(c[x][a] * dp[y][b])
关于x到终点的所有路径长度的贡献为
(c[y][b] * dp[x][a])

最终边权所属分类为((a+b+edge) \% 3) 累加到答案即可

关于更新 x
用 y 来更新 x
(dp[x][(a+edge)\%3] += dp[y][a] + edge * c[y][a])
(ok[x][(a+edge)\%3] = true)
(c[x][(a+edge)\%3] += c[y][a])

当然点分治也可以做,但是复杂度就不是很优秀了

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 10010;
const int M = 200010;
const ll mod = 1e9 + 7;
int head[N],ver[M],nxt[M],tot;
int x,y,n;
bool ok[N][3];
ll edge[M],z,dp[N][3],c[N][3];
ll ans[3];
void add(int x,int y,ll z){
    ver[++tot] = y;edge[tot] = z;nxt[tot] = head[x];head[x] = tot;
}
void dfs(int x,int fa){
    for(int i=head[x];i;i=nxt[i]){
        int y = ver[i];
        if(y == fa)continue;
        dfs(y,x);
        ll z = edge[i];
        for(int j=0;j<3;j++){
            for(int k=0;k<3;k++){
                if(ok[x][j] && ok[y][k]){
                    ans[(j+k+z)%3] += (dp[x][j] * c[y][k]% mod + dp[y][k] * c[x][j] % mod) % mod;
                    ans[(j+k+z) % 3] += z * c[x][j] * c[y][k] % mod;
                    ans[(j+k+z)%3] %= mod;
                }
            }
        }
        for(int j=0;j<3;j++){
            if(ok[y][j]){
                dp[x][(j+z) % 3] += dp[y][j] + z * c[y][j] % mod;
                dp[x][(j+z) % 3] %= mod;
                c[x][(j+z)%3] += c[y][j];
                ok[x][(j+z) % 3] = true;
            }
        }
    }
}
int main(){
    while(~scanf("%d",&n)){
        ans[0] = ans[1] = ans[2] = 0;
        tot = 0;
        for(int i=1;i<=n;i++){
            dp[i][0] = dp[i][1] = dp[i][2] = 0;
            c[i][1] = c[i][2] = 0;
            c[i][0] = 1;
            ok[i][0] = true;
            ok[i][1] = ok[i][2] = false;
            head[i] = 0;
        }
        for(int i=1;i<n;i++){
            scanf("%d%d%lld",&x,&y,&z);
            x ++;
            y ++;
            add(x,y,z);
            add(y,x,z);
        }
        dfs(1,0);
        printf("%lld %lld %lld
",ans[0] * 2 % mod,ans[1] * 2 % mod,ans[2] * 2 % mod);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/1625--H/p/11521171.html