CF1260F

题目大意

一棵树,每个节点的权为L[i]~R[i],一棵树的贡献为(sumlimits_{h_{i} = h_{j}, 1 le i < j le n}{dis(i,j)}),其中(dis(i,j))表示i到j路径上的边数

(prodlimits_{i = 1}^{n} (r_{i} - l_{i} + 1))种不同取值的情况的贡献和

题解

一眼点分治

把一个点的子树上除该点的乘积加到对应区间上,查找就直接找对应区间

设S[i]表示(prod_{i eq j}{R_i-L_i+1}),W[i]表示(frac{1}{R_i-L_i+1})

那么在b点查询a时一种方案的贡献为(S[a]*W[b]*(dis[a]+dis[b])),拆开后为(W[b]*S[a]*dis[a]+W[b]*dis[b]*S[a]),维护S[a]*dis[a]与S[a]的和即可

线段树常数较大,所以要用树状数组

code

#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#define fo(a,b,c) for (a=b; a<=c; a++)
#define fd(a,b,c) for (a=b; a>=c; a--)
#define add(a,b) a=((a)+(b))%1000000007
#define max(a,b) (a>b?a:b)
#define low(x) (x&-(x))
#define mod 1000000007
#define Mod 1000000005
#define N 100000
using namespace std;

int a[200001][2];
int tr[100001][4];
bool Tr[100001];
int d[100001];
int ls[100001];
int L[100001];
int R[100001];
int size[100001];
bool bz[100001];
int Fa[100001];
long long S[100001];
long long w[100001];
int n,i,j,k,l,len,mn1,mn2,SIZE,tot;
long long sum,ans,W,D,WD,find1,find2;

void New(int x,int y)
{
	++len;
	a[len][0]=y;
	a[len][1]=ls[x];
	ls[x]=len;
}

long long qpower(long long a,int b)
{
	long long ans=1;
	
	while (b)
	{
		if (b&1)
		ans=ans*a%mod;
		
		a=a*a%mod;
		b>>=1;
	}
	
	return ans;
}

void Change(int t,long long s1,long long s2)
{
	long long S1=-s1*(t-1)%mod,S2=-s2*(t-1)%mod;
	
	while (t<=N)
	{
		if (!Tr[t])
		Tr[t]=1,d[++tot]=t;
		
		add(tr[t][0],s1);
		add(tr[t][1],S1);
		add(tr[t][2],s2);
		add(tr[t][3],S2);
		
		t+=low(t);
	}
}
void change(int l,int r,long long s1,long long s2)
{
	Change(l,s1,s2);
	Change(r+1,-s1,-s2);
}

void Find(int t,int type)
{
	int T=t;
	
	long long s1,S1,s2,S2;
	s1=S1=s2=S2=0;
	
	while (t)
	{
		add(s1,tr[t][0]);
		add(S1,tr[t][1]);
		add(s2,tr[t][2]);
		add(S2,tr[t][3]);
		
		t-=low(t);
	}
	
	add(find1,(s1*T+S1)%mod*type);
	add(find2,(s2*T+S2)%mod*type);
}
void find(int l,int r)
{
	find1=find2=0;
	Find(r,1);
	Find(l-1,-1);
}

void dfs1(int fa,int t)
{
	int i,mx=0;
	
	Fa[t]=fa;
	size[t]=1;
	
	for (i=ls[t]; i; i=a[i][1])
	if (!bz[a[i][0]] && a[i][0]!=fa)
	{
		dfs1(t,a[i][0]);
		
		size[t]+=size[a[i][0]];
		mx=max(mx,size[a[i][0]]);
	}
	mx=max(mx,SIZE-size[t]);
	
	if (mx<mn1)
	mn1=mx,mn2=t;
}

void dfs2(int fa,int t,int dis)
{
	int i;
	
	W=w[t];D=dis;WD=W*D%mod;
	
	find(L[t],R[t]);
	
	add(ans,WD*find1+W*find2);
	
	for (i=ls[t]; i; i=a[i][1])
	if (!bz[a[i][0]] && a[i][0]!=fa)
	dfs2(t,a[i][0],dis+1);
}

void dfs3(int fa,int t,int dis)
{
	int i;
	
	change(L[t],R[t],S[t],S[t]*dis%mod);
	
	for (i=ls[t]; i; i=a[i][1])
	if (!bz[a[i][0]] && a[i][0]!=fa)
	dfs3(t,a[i][0],dis+1);
}

void work(int t,int Size)
{
	int i;
	
	SIZE=Size;
	mn1=Size;
	mn2=t;
	
	dfs1(0,t);
	t=mn2;
	bz[t]=1;
	
//	---
	
	tot=0;
	change(L[t],R[t],S[t],0);
	
	for (i=ls[t]; i; i=a[i][1])
	if (!bz[a[i][0]])
	{
		dfs2(t,a[i][0],1);
		dfs3(t,a[i][0],1);
	}
	
	fo(i,1,tot)
	tr[d[i]][0]=tr[d[i]][1]=tr[d[i]][2]=tr[d[i]][3]=Tr[d[i]]=0;
	
//	---
	
	for (i=ls[t]; i; i=a[i][1])
	if (!bz[a[i][0]] && a[i][0]!=Fa[t])
	work(a[i][0],size[a[i][0]]);
	
	if (Fa[t])
	work(Fa[t],Size-size[t]);
}

int main()
{
//	freopen("f.in","r",stdin);
//	freopen("b.out","w",stdout);
	
	sum=1;
	
	scanf("%d",&n);
	fo(i,1,n)
	scanf("%d%d",&L[i],&R[i]),sum=(sum*(R[i]-L[i]+1))%mod,w[i]=qpower(R[i]-L[i]+1,Mod);
	fo(i,2,n)
	{
		scanf("%d%d",&j,&k);
		
		New(j,k);
		New(k,j);
	}
	
	fo(i,1,n)
	S[i]=sum*w[i]%mod;
	
	work(1,n);
	
	printf("%I64d
",(ans+mod)%mod);
}
原文地址:https://www.cnblogs.com/gmh77/p/11953323.html