倍增并查集(萌萌哒)

传送门

说实话,这道题一开始我还真没想到是并查集。看到题目,第一个反应是暴力打标记,因为相同的一段的只需要找一次,打上标记后就意味着不会再对答案做出贡献。但这样显然是超时的对吧,所以就想着可不可以拿个啥数据结构来维护。虽然没有想出来。

先不说思想有没有什么bug,但感觉就算这样打,维护标记也是不好维护的。一整块打标记用线段树就很好实现,但是在下一条信息的时候,万一和前面的信息区间有重叠,准确找出又需要打多少标记还是有点麻烦的(应该是对于这两段都求一下是否有部分被打了标记,两者的标记合起来,再被区间长度减去才是这一次对ans的贡献,但是怎么求出两者的标记合起来覆盖了多长的区间呢?不可能是单纯的相加,它的位置是不好定位的)

所以我的代码一开始是这样的(十分暴力但还有一个点没有考虑到)

for(int i=1;i<=m;++i)
    {
        w[i].l1=read();w[i].r1=read();
        w[i].l2=read();w[i].r2=read();
        w[i].c=w[i].r1-w[i].l1+1;
        int s1=w[i].l1,s2=w[i].l2;
        while(s1<=w[i].r1&&s2<=w[i].r2)
        {
            if(!flagg[s1]&&!flagg[s2])ans++;
            flagg[s1]=1,flagg[s2]=1;
            s1++;s2++;
        }
    }
for(int i=1;i<=n;++i)
    if(!flagg[i])ans++;
wrong

然后为什么这样打连思想都是错的呢?

我对拍的时候发现了这样一组数据

8 3
6 6 2 2
3 3 7 7
6 7 3 4

输出的答案是90000,但正确的答案应该是9000。

于是发现这种思想是有bug的==

单纯的打标记是不对的,比如这组数据:我们给6,2打上标记,ans+1,再给3,7打上标记,ans+1,然后第三条信息的时候,我们不会再打标记

加上1,5,8三个没有标记的,最后的ans是5

但是这个打标记应该是有标记的“序号”的

第一条信息

我们给6,2打上标记1

第二条信息

我们给3,7打上标记2

第三条信息

我们发现6,3是一样的,那么标记1和标记2就是一样的,也就是说2,6,3,7,4都是一样的数!那么只对ans贡献1

所以最后的ans是4


再把模型抽象一下,这,这不就是并查集嘛==

朴素的思考就是每个点都建并查集,每条信息时把对应的点一个一个合并。

但时间过不了,所以这里就有一个很神奇的算法,倍增并查集,把n优化成logn就可以过了

#include<bits/stdc++.h>
#define N 100003
#define mod 1000000007
#define LL long long
using namespace std;
int read()
{
    int x=0,f=1;char s=getchar();
    while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();}
    while(s>='0'&&s<='9'){x=x*10+s-'0';s=getchar();}
    return x*f;
}
int f[N][22];
int getfa(int x,int j)
{
    if(f[x][j]==x) return x;
    return f[x][j]=getfa(f[x][j],j);
}
void merge(int x,int y,int j)
{
    f[getfa(x,j)][j]=getfa(y,j);
}
int main()
{
    int n=read(),m=read();
    int ans=0,op=0;
    for(int i=1;i<=n;++i)
       for(int j=0;j<=20;++j)
        f[i][j]=i;//从i起始,2^j的长度的区间 
    for(int i=1;i<=m;++i)
    {
        int l1=read(),r1=read();
        int l2=read(),r2=read();
        for(int j=20;j>=0;--j)
          if(l1+(1<<j)-1<=r1){merge(l1,l2,j);l1+=(1<<j);l2+=(1<<j);}//找到最大的区间合并,再起点挪动 
    }
    for(int j=20;j>=1;--j)
    {
        for(int i=1;i+(1<<j)-1<=n;++i)//!!! 边界处理注意 
        {
            merge(i,getfa(i,j),j-1); //再一层层将区间分半合并 
            merge(i+(1<<(j-1)),getfa(i,j)+(1<<(j-1)),j-1); 
        }
    }
    int md=20;
    for(int i=1;i<=n;++i)
    {
        if(getfa(i,0)==i)ans++;
    }
    LL res=1;
    ans--;
    res=res*9%mod;
    for(int i=1;i<=ans;++i)
      res=res*10%mod;
    printf("%lld
",res);
} 
AC
原文地址:https://www.cnblogs.com/yyys-/p/11258312.html