●(考试失误导致的)倍增总结

  • 个人理解:按照翻倍的形式处理数据,使得加速某些询问过程
  • 具体应用:
    • 快速幂:O(log2n)

      计算ab的值,通常是枚举b个a相乘,但若b很大时,则需要用到快速幂。

      举一个栗子来说明快速幂的做法:

      若要求2160的值,则把60看成二进制数111100,

      即我们要求得21111100(2)的值,

      又因为基于ab1×ab2×ab3×……abn = ab1+b2+b3+……+bn

      所以21111100(2)可以看成是21000100(2)×21001000(2)×21010000(2)

                                                  ×21100000(2)

      可以直接从低到高枚举b(即60)的二进制数每一位计算出这一位为1的二进制数对应的幂值,用倍增的思想,快速求出a的2k次方:

      如a1000(2)=a100(2)×a100(2)

      若b的该位为1,则把该幂值乘进答案。

       

    • 代码:

      const int P=100007;
      int pow(int a,int b) //求a^b,对P取模
      {
          int ans=1; a%=P;
          while (b)
          {
              if (b&1) ans=1ll*ans*a%P;
              b>>=1;
              a=1ll*a*a%mo;
          }
          return ans;
      }
    • RMQ (Range Minimum/Maximum Query) 用倍增算法求ST表:

      预处理 O(n log2 n) 查询 O(1)

      举个栗子:给出一个序列,多次询问区间[ l , r ]中的最小值

      用st[i][j],表示以i位置作为结尾,包含该位在内,前2j个元素的最小值

      rmq-st

      通过如下代码,用倍增思想,递推出ST表(包含询问代码)

      #include<iostream>
      using namespace std;
      int n,a,b;
      int st[10005][15];
      int log2[10005];
      void make_ST(){
          for(int j=1;(1<<j)<=n;j++)
              for(int i=(1<<j);i<=n;i++)
                  st[i][j]=min(st[i][j-1],st[i-(1<<(j-1))][j-1]);
      }
      int query(int l,int r){
          int k=log2[r-l+1];
          return min(st[l+(1<<k)-1][k],st[r][k]);
      }
      int main()
      {
          scanf("%d",&n);
          log2[1]=0;;
          for(int i=2;i<=n;i++) log2[i]=log2[i>>1]+1;
          for(int i=1;i<=n;i++) scanf("%d",&st[i][0]);
          make_ST();
          while(scanf("%d%d",&a,&b)!=EOF) printf("%d
      ",query(a,b));
          return 0;
      }
    • 树上倍增法求LCA:预处理O(n log2 n) 查询 O(log2 n)

      同样的,用一个fa[i][j]表示从i节点向上走2j个单位到哪个节点,代码:

      void dfs(int u,int father){         //从根结点u开始dfs,u的父亲是fa
          d[u]=d[fa]+1;                     //u的深度为它父亲的深度+1
          fa[u][0]=father;                 //u向上走2^0步到达的结点是其父亲
          for(int i=1;(1<<i)<=d[u];i++) 
              fa[u][i]=fa[fa[u][i-1]][i-1];    //预处理fa时,保证能从u走到fa[u][i]
          for(int i=head[u];i!=-1;i=e[i].next){//对于u的每个儿子 
              int v=e[i].v; 
              if(v!=father)dfs(v,u);
          }                                 //递归处理以v为根结点的子树} 
      } 
      int lca(int a,int b)
      {
          if(d[a] > d[b]) swap(a , b) ;     //保证a点在b点上面
          for(int j = 20 ; j >= 0 ; j--)     //将b点向上移动到和a的同一深度
              if(d[a] <= d[b] - (1<<j)) b = fa[b][j] ;
          if(a == b) return a ;            //如果a和b相遇
          for(int j = 20 ; j >= 0 ; j--){    //a,b同时向上移动
              if(fa[a][j] == fa[b][j]) continue ;//如果a,b的2^j祖先相同,则不移动
              a = fa[a][j] , b = fa[b][j] ;//否则同时移动到2^j处
          }
           return fa[a][0] ;                //返回最后a的父亲结点
      }
    • 倍增求ST表,RMQ求LCA:预处理O(n log2 n) 查询 O(1)

      把树转成序列,在序列中查询两点的lca

      对数进行dfs遍历,遇到一个点(即使以前遇到过),就把它加入一个序列数组

      同时记录下每个节点在序列中第一次出现的位置

      lca-st

      不难发现,对于询问的两个点(a,b且假设dfs时a比b先被遍历),在dfs时,从a遍历到b的过程中,一定会经过他们的LCA,即在遍历序列中,它们的LCA一定被记录在了这两个点之间:

      举个栗子:比如要求4号节点和8号节点的LCA,则可以发现在遍历序列中,in[4]位置到in[8]位置之间,出现里一次它们的LCA:1 号节点

      那该如何找到它们的LCA呢。

      可以发现,在深度序列中,in[4]位置到in[8]位置之间,最小值(为1)对应的即是遍历序列中的它们的LCA:1 号节点

      结论:两个节点在深度序列中对应的区间的最小值即是它们的LCA的深度

      有了“区间最小值”这一东西,之前的ST表,RMQ就可以拿来用了

      /*
         stm[i][j]:在深度序列中,以i位置作为结尾,包含该位在内,前2^j个元素的最小值(最小深度) 
          stu[i][j]:stm[i][j]记录的最小值对应的节点 
      */
      void dfs(int u,int fa,int dep){
          deep[u]=deep; 
          stm[++cnt][0]=deep;    stu[cnt][0]=u;    in[u]=cnt;
          for(int i=head[u];i;i=e[i].next){
              int v=e[i].to;
              if(v==fa) continue;
              dfs(v,u,deep+1);
              stm[++cnt][0]=deep;    stu[cnt][0]=u;
          }
      }
      void make_ST(){
          for(int j=1;(1<<j)<=cnt;j++)
              for(int i=(1<<j);i<=cnt;i++)
                  if(stm[i][j-1]<stm[i-(1<<(j-1))][j-1])
                      stm[i][j]=stm[i][j-1],stu[i][j]=stu[i][j-1];
                  else stm[i][j]=stm[i-(1<<(j-1))][j-1],stu[i][j]=stu[i-(1<<(j-1))][j-1];
      }
      
      int dis(int x,int y){
          int l,r,k,lca;
          l=min(in[x],in[y]);
          r=max(in[x],in[y]);
          k=log2[r-l+1];
          if(stm[l+(1<<k)-1][k]<stm[r][k]) lca=stu[l+(1<<k)-1][k];
          else lca=stu[r][k];
      }
    • 后缀数组倍增算法 O(n log2n)
    • 突然想起,补上模板。

void build(int n,int m)
{
    int *x=wa,*y=wy,i;
    for(i=0;i<m;i++) c[i]=0;
    for(i=0;i<n;i++) c[x[i]=s[i]]++;
    for(i=1;i<m;i++) c[i]+=c[i-1];
    for(i=n-1;i>=0;i--) sa[--c[x[i]]]=i;
    for(int k=1;k<=n;k<<=1)
    {
        int p=0;
        for(i=n-k;i<n;i++) y[p++]=i;
        for(i=0;i<n;i++) if(sa[i]>=k) y[p++]=sa[i]-k;
        for(i=0;i<m;i++) c[i]=0;
        for(i=0;i<n;i++) c[x[y[i]]]++;
        for(i=1;i<m;i++) c[i]+=c[i-1];
        for(i=n-1;i>=0;i--) sa[--c[x[y[i]]]]=y[i];
        m=1; swap(x,y); x[sa[0]]=0;
        for(i=1;i<n;i++)
            x[sa[i]]=cmp(y,i,k,n)?m-1:m++;
        if(m>=n) break;
    }
    for(i=0;i<n;i++) rank[sa[i]]=i;
    int h=0;
    for(i=0;i<n;i++)
    {
        if(h) h--;
        if(rank[i]==0) continue;
        int j=sa[rank[i]-1];
        while(s[i+h]==s[j+h]) h++;
        height[rank[i]]=h;
    }
}
  • 某些用到倍增的题目 
    • NOIP 2012 开车旅行 codevs 1199
    • 暴力的话是O(n2+nm),由于决策单一(即某人从某点出发到下一点这一过程是唯一确定的),可以进行倍增加速。
    • 由于是两人交替走,比一般的路径倍增要麻烦一点
    • 先借助set预处理出两个人分别从i号点向前走的下一个点是哪个以及走的距离。
    • 然后用to[i][j]表示从i号点出发,走2j轮(一轮为小A先走,小B再走)到达的目的地。用dis[i][j][0/1](0:小B,1:小A)与上面的to数组对应,即分别表示从i号点出发,走2j轮,小B/小A走过的距离和。

      这样通过倍增后,加速了答案的寻找过程

      将时间复杂度优化为了O(n log2n+m log2n)

      更加详细的大佬题解-------------------------------------------------------->

      代码:

      #include<set> 
      #include<cstdio>
      #include<cstring>
      #include<iostream>
      #define INF 0x3f3f3f3f
      #define eps 0.000003
      #define MAXN 100005
      #define siter set<info>::iterator
      #define info(a,b) (info){a,b}
      using namespace std;//0 小B        1 小A 
      struct info{
          int h,p;
          bool operator <(const info rtm) const{
              return h<rtm.h;
          }
      };
      int to[MAXN][25],dis[MAXN][25][2],des[MAXN][2],len[MAXN][2],ans[2];
      int he[MAXN];
      int n,m,x,st;
      set<info>s;
      int sign(double i){
          if(-eps<=i&&i<=eps) return 0;
          if(i<-eps) return -1;
          return 1;
      }
      int distant(int i,int j){
          return he[i]>he[j]?he[i]-he[j]:he[j]-he[i];
      }
      void update(int i,info p){
          int j=p.p,d=distant(i,j);
          if(d<len[i][0]||(d==len[i][0]&&he[j]<he[des[i][0]])){
              len[i][1]=len[i][0];des[i][1]=des[i][0];
              len[i][0]=d;des[i][0]=j;     
          }
          else if(d<len[i][1]||(d==len[i][1]&&he[j]<he[des[i][1]])){
              len[i][1]=d;des[i][1]=j;
          }
      }
      void drive(int i,int v){
          for(int j=20;j>=0;j--) if(dis[i][j][0]+dis[i][j][1]<=v&&to[i][j]){
              ans[0]+=dis[i][j][0];
              ans[1]+=dis[i][j][1];
              v=v-dis[i][j][0]-dis[i][j][1];
              i=to[i][j];
          }
          if(len[i][1]<=v&&des[i][1]) ans[1]+=len[i][1];
      }
      int main()
      {
          scanf("%d",&n);
          for(int i=1;i<=n;i++) scanf("%d",&he[i]),len[i][0]=len[i][1]=INF;
          siter si;
          for(int i=n;i>=1;i--){
              s.insert(info(he[i],i));
              si=s.find(info(he[i],i));
              si++;if(si!=s.end()){
                  update(i,*si);
                  si++; if(si!=s.end()) update(i,*si); si--;
              }
              si--;if(si!=s.begin()){
                  si--; update(i,*si);
                  if(si!=s.begin()) si--,update(i,*si);
              }
          }
          for(int i=1;i<=n;i++) 
              to[i][0]=des[des[i][1]][0],
              dis[i][0][1]=len[i][1],
              dis[i][0][0]=len[des[i][1]][0];
          for(int j=1;j<=20;j++)
              for(int i=1;i<=n;i++)
                  to[i][j]=to[to[i][j-1]][j-1],
                  dis[i][j][0]=dis[i][j-1][0]+dis[to[i][j-1]][j-1][0],
                  dis[i][j][1]=dis[i][j-1][1]+dis[to[i][j-1]][j-1][1];
          scanf("%d",&x);
          double rat=1e9; int ap=0;
          for(int i=1;i<=n;i++){
              ans[0]=ans[1]=0;
              drive(i,x);
              double tmp=ans[0]? 1.0*ans[1]/ans[0]:1e9;
              if(sign(rat-tmp)==0&&he[i]>he[ap]) ap=i;
              if(sign(rat-tmp)>0) ap=i,rat=tmp;
          }
          printf("%d
      ",ap);
          scanf("%d",&m);
          for(int i=1;i<=m;i++){
              ans[0]=ans[1]=0;
              scanf("%d%d",&st,&x);
              drive(st,x);
              printf("%d %d
      ",ans[1],ans[0]);
          }
          return 0;
      }
    • NOIP 2012 疫情控制 洛谷 1084

    • 贪心+二分+倍增

    • 二分时间,check操作,将所有军队按能否到达根节点分成两类

      A类:无法在二分的时间内达到根节点。

      根据贪心策略,将这些军队移动到尽可能靠上的位置一定更优,所以把他们移动到他们所能到达的最靠近根的位置

      B类:在二分的时间内可以到达根节点。

      把他们放入一个数组,按到达根节点后剩余的时间从小到大排序

      再对树跑一个dfs,维护出根的哪些儿子节点还需要一个B类军队去驻扎,把这些儿子节点放入另一个数组,按到根的时间从小到大排序

      进行贪心,尝试用B类军队去覆盖没有还需要被驻扎的(根的儿子)节点:

      对于从小到大枚举到的某一个B类军队,首先判断他到根节点时进过的那个根的儿子节点是否被驻扎,若没有,则直接去驻扎那个节点。若已被驻扎,则尝试去驻扎从小到大枚举到的还需要被驻扎的第一个节点。(有一点绕,好好理解一下,正确性很容易证明)

      最后判断该时间下,那些还需要被驻扎的(根的儿子)节点是否被驻扎完。

      至于倍增用在哪里,显而易见,在将军队向上移动时,不可能一个一个地向father移动,所以倍增一下,加速移动过程。

      代码:

      洛谷和Vijos上过了,但Codevs和Tyvj上却WA了一个点,在Tyvj上把数据下了下来,手测却发现输出是正确的……

      不明原因,非常绝望,望有心人能解答疑难。

      #include<cstdio>
      #include<cstring>
      #include<iostream>
      #include<algorithm>
      #define ll long long
      #define MAXN 50005
      using namespace std;
      struct edge{
          int to,next;
          ll val;
      }e[MAXN*2];
      struct node{
          int id; ll val;
          bool operator<(const node &rtm) const{
              return val<rtm.val;
          }
      }ar[MAXN],ne[MAXN];
      ll stt[MAXN][20];
      int stu[MAXN][20];
      int p[MAXN],from[MAXN],head[MAXN];
      bool vis[MAXN];
      ll l,r,mid,ans;
      int n,m,ent=1,rs,cnt,nnt;
      void add(int u,int v,int w){
          e[ent]=(edge){v,head[u],1ll*w};
          head[u]=ent++;
      }
      void dfs(int u,int fa,ll dis,int fr){
          if(fa==1) rs++;
          stu[u][0]=fa;
          stt[u][0]=dis;
          if(fa==1) from[u]=u;
          else from[u]=fr;
          for(int j=1;j<=16;j++){
              stu[u][j]=stu[stu[u][j-1]][j-1];
              stt[u][j]=stt[u][j-1]+stt[stu[u][j-1]][j-1];
          }
          for(int i=head[u];i;i=e[i].next){
              int v=e[i].to;
              if(v==fa) continue;
              if(u==1) dfs(v,u,e[i].val,v);
              else dfs(v,u,e[i].val,fr);
          }
      }
      void update(int u,int fa){
          bool fg=1,fl=0;
          for(int i=head[u];i;i=e[i].next){
              int v=e[i].to;
              if(v==fa) continue;
              fl=1;
              update(v,u);
              if(!vis[v]) fg=0;
              if(u==1&&!vis[v]) ne[++nnt]=(node){v,e[i].val};
          }
          if(fl) vis[u]=fg|vis[u];
      }
      bool check(ll x){
          ll tmp;int u;
          cnt=0; nnt=0;
          memset(vis,0,sizeof(vis));vis[0]=1;
          for(int i=1;i<=m;i++){
              tmp=x; u=p[i];
              for(int j=16;j>=0;j--)if(stu[u][j]&&tmp>=stt[u][j]){
                  tmp-=stt[u][j];
                  u=stu[u][j];
              }
              if(u==1) ar[++cnt]=(node){p[i],tmp};
              else vis[u]=1;
          }
          update(1,0);
          sort(ne+1,ne+nnt+1);
          sort(ar+1,ar+cnt+1);
          int pp=1,res=nnt;
          for(int i=1;i<=cnt;i++){
              while(vis[ne[pp].id]) pp++;
              if(!vis[from[ar[i].id]]){
                  vis[from[ar[i].id]]=1;
                  res--;
              }
              else{
                  if(ar[i].val>=ne[pp].val){
                      vis[ne[pp].id]=1;
                      res--;
                  }
              }
              if(!res) return 1;
          }
          return 0;
      }
      void Binary(){
          while(l<=r){
              mid=(l+r)/2;
              if(check(mid)) ans=mid,r=mid-1;
              else l=mid+1;
          }
          printf("%lld",ans);
      } 
      int main()
      {
          scanf("%d",&n);
          l=1; r=0;
          for(int i=1,a,b,c;i<n;i++){
              scanf("%d%d%d",&a,&b,&c);
              add(a,b,c); add(b,a,c);
              r+=1ll*c;
          }
          dfs(1,0,0,0);
          scanf("%d",&m);
          for(int i=1;i<=m;i++)
              scanf("%d",&p[i]);
          if(m<rs) printf("-1
      ");
          else Binary();
          return 0;
      }
  • 总结:

    由NOIP2012的两道题来看(一年考的两个题都涉及到倍增,而且代码都这么恶心……),倍增多半不会直接考,而是在解题时需要用到上文中的某几个应用来优化,或者说是加速某些答案(或者中间答案)的寻找过程。

    倍增是一种十分常用且实用的技巧,特别是那几个应用,一定要掌握到位。

原文地址:https://www.cnblogs.com/zj75211/p/7568880.html