【BZOJ4821】[Sdoi2017]相关分析 线段树

【BZOJ4821】[Sdoi2017]相关分析

Description

Frank对天文学非常感兴趣,他经常用望远镜看星星,同时记录下它们的信息,比如亮度、颜色等等,进而估算出星星的距离,半径等等。Frank不仅喜欢观测,还喜欢分析观测到的数据。他经常分析两个参数之间(比如亮度和半径)是否存在某种关系。现在Frank要分析参数X与Y之间的关系。他有n组观测数据,第i组观测数据记录了x_i和y_i。他需要一下几种操作1 L,R:用直线拟合第L组到底R组观测数据。用xx表示这些观测数据中x的平均数,用yy表示这些观测数据中y的平均数,即
xx=Σx_i/(R-L+1)(L<=i<=R)
yy=Σy_i/(R-L+1)(L<=i<=R)
如果直线方程是y=ax+b,那么a应当这样计算:
a=(Σ(x_i-xx)(y_i-yy))/(Σ(x_i-xx)(x_i-xx)) (L<=i<=R)
你需要帮助Frank计算a。
2 L,R,S,T:
Frank发现测量数据第L组到底R组数据有误差,对每个i满足L <= i <= R,x_i需要加上S,y_i需要加上T。
3 L,R,S,T:
Frank发现第L组到第R组数据需要修改,对于每个i满足L <= i <= R,x_i需要修改为(S+i),y_i需要修改为(T+i)。

Input

第一行两个数n,m,表示观测数据组数和操作次数。
接下来一行n个数,第i个数是x_i。
接下来一行n个数,第i个数是y_i。
接下来m行,表示操作,格式见题目描述。
1<=n,m<=10^5,0<=|S|,|T|,|x_i|,|y_i|<=10^5
保证1操作不会出现分母为0的情况。

Output

对于每个1操作,输出一行,表示直线斜率a。
选手输出与标准输出的绝对误差不超过10^-5即为正确。

Sample Input

3 5
1 2 3
1 2 3
1 1 3
2 2 3 -3 2
1 1 2
3 1 2 2 1
1 1 3

Sample Output

1.0000000000
-1.5000000000
-0.6153846154

题解:显然直接用线段树,需要维护一下几个东西:x之和,x^2之和,y之和,x*y之和。并且同时还要支持区间赋值和区间加。没有细节,就是讨论。

这里只说如何处理x*y吧,对于区间加,$sum(x+v)(y+v)=sum x*y+v*sum (x+y)+v*v*n$。对于区间赋值,$sumlimits_{i=1}^n(a+i)(b+i)=sumlimits_{i=1}^nab+(a+b)sumlimits_{i=1}^ni+sumlimits_{i=1}^ni^2$。

#include <cstdio>
#include <cstring>
#include <iostream>
#define ls t<<1
#define rs t<<1|1
#define S2(_) ((_)*((_)+1)*(2*(_)+1)/6)
using namespace std;
typedef double db;
const int maxn=100010;
int n,m;
db X,Y,XX,YY,XY;
db x[maxn<<2],y[maxn<<2],xx[maxn<<2],yy[maxn<<2],xy[maxn<<2],xa[maxn],ya[maxn],sx[maxn<<2],sy[maxn<<2];
db tx[maxn<<2],ty[maxn<<2];
db ans1,ans2,_X,_Y;
inline void pushup(int t)
{
	x[t]=x[ls]+x[rs],y[t]=y[ls]+y[rs],xx[t]=xx[ls]+xx[rs],yy[t]=yy[ls]+yy[rs],xy[t]=xy[ls]+xy[rs];
}
inline void pds(db siz,int t,db a,db b)
{
	xx[t]+=2*a*x[t]+a*a*siz,yy[t]=2*b*y[t]+b*b*siz,xy[t]=xy[t]+b*x[t]+a*y[t]+a*b*siz,x[t]+=a*siz,y[t]+=b*siz;
	if(tx[t]!=1e15)	tx[t]+=a,ty[t]+=b;
	else	sx[t]+=a,sy[t]+=b;
}
inline void pdt(int l,int r,int t,db a,db b)
{
	db siz=r-l+1;
	x[t]=(a+l+a+r)*siz/2,y[t]=(b+l+b+r)*siz/2,xx[t]=S2(a+r)-S2(a+l-1),yy[t]=S2(b+r)-S2(b+l-1);
	xy[t]=(a+l-1)*(b+l-1)*siz+(a+l-1+b+l-1)*siz*(siz+1)/2+S2(siz);
	sx[t]=sy[t]=0,tx[t]=a,ty[t]=b;
}
inline void pushdown(int l,int r,int t)
{
	int mid=(l+r)>>1;
	if(sx[t]||sy[t])	pds(mid-l+1,ls,sx[t],sy[t]),pds(r-mid,rs,sx[t],sy[t]),sx[t]=sy[t]=0;
	if(tx[t]!=1e15)	pdt(l,mid,ls,tx[t],ty[t]),pdt(mid+1,r,rs,tx[t],ty[t]),tx[t]=ty[t]=1e15;
}
void build(int l,int r,int t)
{
	tx[t]=ty[t]=1e15;
	if(l==r)
	{
		x[t]=xa[l],y[t]=ya[l],xx[t]=x[t]*x[t],yy[t]=y[t]*y[t],xy[t]=x[t]*y[t];
		return ;
	}
	int mid=(l+r)>>1;
	build(l,mid,ls),build(mid+1,r,rs);
	pushup(t);
}
void ups(int l,int r,int t,int a,int b,db A,db B)
{
	if(a<=l&&r<=b)
	{
		pds(r-l+1,t,A,B);
		return ;
	}
	pushdown(l,r,t);
	int mid=(l+r)>>1;
	if(a<=mid)	ups(l,mid,ls,a,b,A,B);
	if(b>mid)	ups(mid+1,r,rs,a,b,A,B);
	pushup(t);
}
void upt(int l,int r,int t,int a,int b,db A,db B)
{
	if(a<=l&&r<=b)
	{
		pdt(l,r,t,A,B);
		return ;
	}
	pushdown(l,r,t);
	int mid=(l+r)>>1;
	if(a<=mid)	upt(l,mid,ls,a,b,A,B);
	if(b>mid)	upt(mid+1,r,rs,a,b,A,B);
	pushup(t);
}
void query(int l,int r,int t,int a,int b)
{
	if(a<=l&&r<=b)
	{
		X+=x[t],Y+=y[t],XX+=xx[t],YY+=yy[t],XY+=xy[t];
		return ;
	}
	pushdown(l,r,t);
	int mid=(l+r)>>1;
	if(a<=mid)	query(l,mid,ls,a,b);
	if(b>mid)	query(mid+1,r,rs,a,b);
}
inline int rd()
{
	int ret=0,f=1;	char gc=getchar();
	while(gc<'0'||gc>'9')	{if(gc=='-')	f=-f;	gc=getchar();}
	while(gc>='0'&&gc<='9')	ret=ret*10+gc-'0',gc=getchar();
	return ret*f;
}
int main()
{
	n=rd(),m=rd();
	int i,a,b,c,d,op;
	for(i=1;i<=n;i++)	xa[i]=rd();
	for(i=1;i<=n;i++)	ya[i]=rd();
	build(1,n,1);
	for(i=1;i<=m;i++)
	{
		op=rd(),a=rd(),b=rd();
		if(op==1)
		{
			c=b-a+1,X=Y=XX=YY=XY=0,query(1,n,1,a,b),_X=(db)X/c,_Y=(db)Y/c;
			ans1=XY-_X*Y-_Y*X+_X*_Y*c,ans2=XX-2*_X*X+_X*_X*c;
			printf("%.10lf
",ans1/ans2);
		}
		if(op==2)	c=rd(),d=rd(),ups(1,n,1,a,b,c,d);
		if(op==3)	c=rd(),d=rd(),upt(1,n,1,a,b,c,d);
	}
	return 0;
}
原文地址:https://www.cnblogs.com/CQzhangyu/p/7605760.html