【ybt高效进阶 21165 / 150C】【nowcoder 1103B】树上交集 / 路径计数机(换根DP)(树形DP)

树上交集 / 路径计数机

题目链接:ybt高效进阶 21165 / 150C / nowcoder 1103B

题目大意

给你一棵树,问你能找到多少个四元组 (a,b,c,d),满足 a 到 b 边数为 p,c 到 d 边数为 q,而且两条路径没有交。

思路

考虑求不交比较难,我们搞有交的。

那不难想出两条路径就两种情况,一个是公用同一个 LCA,要么是有一条路径穿过了另一条路径的 LCA。
然后我们就以 LCA 为中心去搞,考虑求出这四个东西:(fp_i,fq_i,gp_i,gq_i),分别表示从 (i) 出发,经过 (p/q) 条边,然后到的是 (i) 子树内 / 外的点。

那不难想到所有的四元组个数就是:(sumlimits_{i=1}^n fp_isumlimits_{i=1}^n fq_i)
然后有交的:(sumlimits_{i=1}^n(fp_ifq_i+fp_igq_i+gp_ifq_i))
(后面两个都是第二个情况,只是谁穿过不同而已)

那接下来就是要求这四个数组。
那不难想出可以先搞路径一段是在 (i) 上的,然后把两个路径拼上得到上面的数组。
(两条路径的长度之和已经确定,就直接枚举一条的长度)

然后设 (f_{i,j},g_{i,j}) 为有多少条路径一端是 (i),另一端在 (i) 子树内 / 外,然后路径长度是 (j) 的路径数。
(f_{i,j}) 可以直接 DP 下去,(f_{i,j}=sum_{k=son_i} f_{k,j-1},f_{i,0}=1)

至于 (g_{i,j}),我们想,如果我们求出以 (i) 为根的时候的 (f_{i,j}),那这个就是不管子树内外的,减去在子树内的(一开始算出的 (f_{i,j})),就是在子树外的了。
那我们考虑换根一下就好了。(记得跑完换回来)

然后就好啦。

代码

#include<map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long

using namespace std;

struct node {
	int to, nxt;
}e[6001];
int n, p, q, x, y;
int le[3001], KK;
ll ans, f[3001][3001], g[3001][3001];//这个是一个端点是 i,另一个端点在 i 子树内 / 外,路径长度为 j 的个数
ll fp[3001], fq[3001], gp[3001], gq[3001];//我们要求的东西
ll nw[3001][3001];

void add(int x, int y) {
	e[++KK] = (node){y, le[x]}; le[x] = KK;
	e[++KK] = (node){x, le[y]}; le[y] = KK; 
}

void dfs_f(int now, int father) {
	f[now][0] = 1;
	for (int i = le[now]; i; i = e[i].nxt)
		if (e[i].to != father) {
			dfs_f(e[i].to, now);
			for (int j = 1; j <= p; j++) {//把两个从 now 出发的拼起来
				fp[now] += (j == 0 ? 1 : f[e[i].to][j - 1]) * f[now][p - j];
			}
			for (int j = 1; j <= q; j++) {
				fq[now] += (j == 0 ? 1 : f[e[i].to][j - 1]) * f[now][q - j];
			}
			for (int j = 1; j < n; j++)
				f[now][j] += f[e[i].to][j - 1];
		}
}

void dfs_g(int now, int father) {
	for (int i = 0; i < n; i++)
		g[now][i] = nw[now][i] - f[now][i];
	g[now][0] = 1;
	for (int i = le[now]; i; i = e[i].nxt)
		if (e[i].to != father) {
			for (int j = 1; j < n; j++)//换根
				nw[now][j] -= nw[e[i].to][j - 1];
			for (int j = 1; j < n; j++)
				nw[e[i].to][j] += nw[now][j - 1];
			dfs_g(e[i].to, now);
			for (int j = 1; j < n; j++)//换回来
				nw[e[i].to][j] -= nw[now][j - 1];
			for (int j = 1; j < n; j++)
				nw[now][j] += nw[e[i].to][j - 1];
		}
	for (int i = 1; i <= p; i++)//计算(两个都是从 now 出发,一个向 now 子树,一个往外)
		gp[now] += g[now][i] * f[now][p - i];
	for (int i = 1; i <= q; i++)
		gq[now] += g[now][i] * f[now][q - i];
}

int main() {
//	freopen("intersection.in", "r", stdin);
//	freopen("intersection.out", "w", stdout);
	
	scanf("%d %d %d", &n, &p, &q);
	for (int i = 1; i < n; i++) {
		scanf("%d %d", &x, &y);
		add(x, y);
	}
	
	dfs_f(1, 0);
	for (int i = 1; i <= n; i++)
		for (int j = 0; j < n; j++)
			nw[i][j] = f[i][j];
	dfs_g(1, 0);
	
	ll lsum = 0, rsum = 0;//按上面进行计算
	for (int i = 1; i <= n; i++)
		lsum += fp[i], rsum += fq[i];
	ans = lsum * rsum;
	for (int i = 1; i <= n; i++)
		ans -= fp[i] * fq[i] + fp[i] * gq[i] + gp[i] * fq[i];
	
	printf("%lld", ans * 4ll);//记得要乘4(ab可以互换,cd可以互换)
	
	return 0;
}
原文地址:https://www.cnblogs.com/Sakura-TJH/p/YBT_GXJJ_21165.html