【SSLOJ1459】空间简单度

题目


\(n\leq 3\times 10^5,K\leq 10\)

思路

考虑用总方案数减去空间简单度不超过 \(k\) 的方案数。
发现 \(k\) 很小,可以枚举所有点 \(i\),那么对于一个 \(|i-j|\leq k\) 的点 \(j\),发现这个点对贡献了路径 \(i\to j\) “两端”点的数量之积。
但是直接计算容易重复,发现每次将是 \(dfs\) 序不超过 3 个区间的点的乘积,那么求出每个点字数点的 \(dfs\) 序区间,然后扔到二维平面上,转换成求矩形面积并的问题。
扫描线+线段树即可。
时间复杂度 \(O(nk\log n)\)

代码

#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=300010,LG=20;
int n,m,tot,cnt1,cnt2,head[N],dfn[N],size[N],f[N][LG+1],dep[N],L[N],R[N];
ll ans;

struct edge
{
	int next,to;
}e[N*2];

struct node
{
	int x,l,r;
}line1[N*40],line2[N*40];

bool operator <(node x,node y)
{
	return x.x<y.x;
}

void add(int from,int to)
{
	e[++tot].to=to;
	e[tot].next=head[from];
	head[from]=tot;
}

void dfs(int x,int fa)
{
	dfn[x]=++tot; size[x]=1;
	f[x][0]=fa; dep[x]=dep[fa]+1;
	for (int i=1;i<=LG;i++)
		f[x][i]=f[f[x][i-1]][i-1];
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa)
		{
			dfs(v,x);
			size[x]+=size[v];
		}
	}
	L[x]=dfn[x]; R[x]=dfn[x]+size[x]-1;
}

int lca(int x,int y)
{
	if (dep[x]<dep[y]) swap(x,y);
	for (int i=LG;i>=0;i--)
		if (dep[f[x][i]]>=dep[y]) x=f[x][i];
	if (x==y) return x;
	for (int i=LG;i>=0;i--)
		if (f[x][i]!=f[y][i])
		{
			x=f[x][i];
			y=f[y][i];
		}
	return f[x][0];
}

int findson(int x,int y)
{
	for (int i=LG;i>=0;i--)
		if (dep[f[y][i]]>dep[x]) y=f[y][i];
	return y;
}

void insert(int l1,int l2,int r1,int r2)
{
	line1[++cnt1]=(node){min(l1,l2),min(r1,r2),max(r1,r2)};
	line2[++cnt2]=(node){max(l1,l2)+1,min(r1,r2),max(r1,r2)};
}

bool cmp(node x,node y)
{
	return x.x<y.x;
}

struct SegTree
{
	int l[N*4],r[N*4],sum[N*4],cnt[N*4];
	
	void build(int x,int ql,int qr)
	{
		l[x]=ql; r[x]=qr;
		if (ql==qr) return;
		int mid=(ql+qr)>>1;
		build(x*2,ql,mid); build(x*2+1,mid+1,qr);
	}
	
	void update(int x,int ql,int qr,int val)
	{
		if (l[x]==ql && r[x]==qr)
		{
			sum[x]+=val;
			if (sum[x]>0) cnt[x]=r[x]-l[x]+1;
			else if (ql==qr) cnt[x]=0;
			else cnt[x]=cnt[x*2]+cnt[x*2+1]; 
			return;
		}
		int mid=(l[x]+r[x])>>1;
		if (qr<=mid) update(x*2,ql,qr,val);
		else if (ql>mid) update(x*2+1,ql,qr,val);
		else update(x*2,ql,mid,val),update(x*2+1,mid+1,qr,val);
		if (sum[x]) cnt[x]=r[x]-l[x]+1;
			else cnt[x]=cnt[x*2]+cnt[x*2+1];
	}
}seg;

int main()
{
	int size = 256 << 20; //250M
	char*p=(char*)malloc(size) + size;
	__asm__("movl %0, %%esp\n" :: "r"(p) );
	memset(head,-1,sizeof(head));
	scanf("%d%d",&n,&m);
	for (int i=1,x,y;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	tot=0;
	dfs(1,0);
	for (int i=1;i<=n;i++)
		for (int j=i+1;j<=min(n,i+m);j++)
		{
			bool flag=0;
			if (dfn[i]>dfn[j]) swap(i,j),flag=1;
			int p=lca(i,j);
			if (p==i)
			{
				int soni=findson(i,j);
				if (L[soni]>1) insert(1,L[soni]-1,L[j],R[j]);
				if (R[soni]<n) insert(L[j],R[j],R[soni]+1,n);
			}
			else insert(L[i],R[i],L[j],R[j]);
			if (flag) swap(i,j);
		}
	seg.build(1,1,n);
	sort(line1+1,line1+1+cnt1);
	sort(line2+1,line2+1+cnt2);
	for (int i=1,j=1,k=1;i<=n;i++)
	{
		for (;line1[j].x==i && j<=cnt1;j++)
			seg.update(1,line1[j].l,line1[j].r,1);
		for (;line2[k].x==i && k<=cnt2;k++)
			seg.update(1,line2[k].l,line2[k].r,-1);
		ans+=seg.cnt[1];
	}
	printf("%lld\n",1LL*n*(n-1)/2LL-ans+n);
	return 0;
}
原文地址:https://www.cnblogs.com/stoorz/p/13477321.html