[HDU4812]D Tree

vjudge
题意:给一棵树,每个点上有一个权值,求一条路径使得路径上权值的乘积膜(10^6+3)的结果为(K),输出路径的两个端点(x,y)。如有多解,设(x<y),输出(x)最小的,若仍有多解输出(y)最小的。

sol

点分。
每次考虑所有过重心的路径,开一个桶(T[x])表示到根路径权值乘积(不算根的权值)为(x)的最小节点编号。
注意要先查出所有点到根的权值乘积,全部更新答案,再去更新桶(T)
更新答案的时候用逆元。逆元可以线性预处理出来。
记得要设(T[1]=u),做完这一层之后也要把(T[1])清空

code

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int gi()
{
	int x=0,w=1;char ch=getchar();
	while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
	if (ch=='-') w=0,ch=getchar();
	while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
	return w?x:-x;
}
const int N = 1e5+5;
const int mod = 1e6+3;
int n,k,inv[mod],val[N],to[N<<1],nxt[N<<1],head[N],cnt;
int sz[N],w[N],root,sum,vis[N],T[mod],dep[N],tmp[N],top,ans1,ans2;
void link(int u,int v){to[++cnt]=v;nxt[cnt]=head[u];head[u]=cnt;}
void getroot(int u,int f)
{
	sz[u]=1;w[u]=0;
	for (int e=head[u];e;e=nxt[e])
	{
		int v=to[e];if (v==f||vis[v]) continue;
		getroot(v,u);
		sz[u]+=sz[v];w[u]=max(w[u],sz[v]);
	}
	w[u]=max(w[u],sum-sz[u]);
	if (w[u]<w[root]) root=u;
}
void getdeep(int u,int f,int sta)
{
	dep[u]=sta;tmp[++top]=u;
	for (int e=head[u];e;e=nxt[e])
	{
		int v=to[e];if (v==f||vis[v]) continue;
		getdeep(v,u,1ll*sta*val[v]%mod);
	}
}
void solve(int u)
{
	vis[u]=1;T[1]=u;
	for (int e=head[u];e;e=nxt[e])
	{
		int v=to[e];if (vis[v]) continue;
		top=0;getdeep(v,0,val[v]);
		for (int i=1;i<=top;++i)
		{
			int xx=1ll*dep[tmp[i]]*val[u]%mod,yy=1ll*k*inv[xx]%mod;
			int x=tmp[i],y=T[1ll*k*inv[1ll*dep[tmp[i]]*val[u]%mod]%mod];
			if (!y) continue;
			if (x>y) swap(x,y);
			if (x<ans1||(x==ans1&&y<ans2)) ans1=x,ans2=y;
		}
		for (int i=1;i<=top;++i) if (!T[dep[tmp[i]]]||tmp[i]<T[dep[tmp[i]]]) T[dep[tmp[i]]]=tmp[i];
	}
	for (int e=head[u];e;e=nxt[e])
	{
		int v=to[e];if (vis[v]) continue;
		top=0;getdeep(v,0,val[v]);
		for (int i=1;i<=top;++i) T[dep[tmp[i]]]=0;
	}
	T[1]=0;
	for (int e=head[u];e;e=nxt[e])
	{
		int v=to[e];if (vis[v]) continue;
		sum=sz[v];root=0;
		getroot(v,0);
		solve(root);
	}
}
int main()
{
	inv[1]=1;
	for (int i=2;i<mod;++i) inv[i]=mod-1ll*(mod/i)*inv[mod%i]%mod;
	while (scanf("%d %d",&n,&k)!=EOF)
	{
		memset(head,0,sizeof(head));cnt=0;
		memset(vis,0,sizeof(vis));ans1=ans2=1e9;
		for (int i=1;i<=n;++i) val[i]=gi();
		for (int i=1;i<n;++i)
		{
			int u=gi(),v=gi();
			link(u,v);link(v,u);
		}
		root=0;sum=w[0]=n;
		getroot(1,0);
		solve(root);
		if (ans1==1e9) puts("No solution");
		else printf("%d %d
",ans1,ans2);
	}
}
原文地址:https://www.cnblogs.com/zhoushuyu/p/8473132.html