一维二维树状数组写法总结

（有任何问题欢迎留言或私聊，也欢迎交流讨论哦）

@




一维习题:hdu1541 bzoj3211(hdu4027)
二维习题：hdu2642、hdu1892、hdu5517

一维树状数组:

// Minimal 1-based Fenwick tree supporting point-add and prefix-sum queries.
struct FenwickTree {
    int BIT[MXN];                          // tree nodes; index 0 is never read by query_bit
    int lowbit(int x) { return x & -x; }   // value of the lowest set bit of x
    // a[x] += val, propagated to every node whose range covers x; N is the active size.
    void add_bit(int x, int val, int N) {
        while (x <= N) {
            BIT[x] += val;
            x += lowbit(x);
        }
    }
    // Returns a[1] + a[2] + ... + a[x].
    int query_bit(int x) {
        int sum = 0;
        while (x) {
            sum += BIT[x];
            x -= lowbit(x);
        }
        return sum;
    }
}bit;

改点求段:

// Point update: a[x] += v. Visits every tree node whose range contains x.
void add(int x,int v){
  for (; x <= n; x += lowbit(x))
    ar[x] += v;
}
// Prefix sum: a[1] + ... + a[x].
int query(int x){
  int total = 0;
  for (; x > 0; x -= lowbit(x))
    total += ar[x];
  return total;
}
// Range sum a[l] + ... + a[r] as a difference of two prefix sums.
int range(int l, int r){
  const int upto_r = query(r);
  const int before_l = query(l - 1);
  return upto_r - before_l;
}

改段求点:

void add(int x,int v){
  while(x <= n){
    delta[x] += v;
    x += lowbit(x);
  }
}
int query(int x){
  int sum = 0;
  while(x > 0){
    sum += delta[x];
    x -= lowbit(x);
  }
  return sum;
}
void init(){
	for(int i=1;i<=n;++i){
      scanf("%d", &ar[i]);
      add(i, ar[i]-ar[i-1]);
    }
}
int get_pos(int x){
	return query(x);
}

改段求段:

here（原文此处为超链接，链接地址在转载时已丢失）

//sum[i] = sigma(ar[x])+(i+1)*sigma(delta[x])-sigma(x*delta[x])
//delta[]是差分数组
// Generic point-add on a Fenwick array a[1..n] of long longs.
void add(LL *a, int x, LL v){
  for (; x <= n; x += lowbit(x))
    a[x] += v;
}
// Prefix sum a[1] + ... + a[x] of a Fenwick array.
LL query(LL *a, int x){
  LL acc = 0;
  for (; x > 0; x -= lowbit(x))
    acc += a[x];
  return acc;
}
// Reads the initial values and builds the static prefix sums pre[0..n].
void init(){
  pre[0] = 0;
  for (int k = 1; k <= n; ++k) {
    scanf("%lld", &ar[k]);
    pre[k] = pre[k - 1] + ar[k];
  }
}
// a[l..r] += x. The difference d is stored in delta, and d*index in deltai.
void update(int l, int r, LL x){
  add(delta, l, x);
  add(delta, r + 1, -x);
  add(deltai, l, l * x);
  add(deltai, r + 1, -x * (r + 1));
}
// Sum of a[l..r] using prefix(i) = pre[i] + (i+1)*sum(delta[1..i]) - sum(deltai[1..i]).
LL range(int l, int r){
  LL left = pre[l - 1] + l * query(delta, l - 1) - query(deltai, l - 1);
  LL right = pre[r] + (r + 1) * query(delta, r) - query(deltai, r);
  return right - left;
}

二维树状数组:

改点求段:

// 2D point update: cell (x, y) += z, on an n x n tree.
void add(int x, int y, int z){
  for (int i = x; i <= n; i += lowbit(i))
    for (int j = y; j <= n; j += lowbit(j))
      cw[i][j] += z;
}
// 2D prefix sum over the rectangle (1, 1)-(x, y).
int query(int x, int y){
  int res = 0;
  for (int i = x; i; i -= lowbit(i))
    for (int j = y; j; j -= lowbit(j))
      res += cw[i][j];
  return res;
}

改段求点:

//d[i][j]表示 a[i][j]与a[i−1][j]+a[i][j−1]−a[i−1][j−1]的差
//delta[][]是差分数组
void add(int x, int y, int z){
  int tmp = y;
  while(x <= n){
    y = tmp;
    while(y <= n){
      delta[x][y] += z, y += lowbit(y);
    }
    x += lowbit(x);
  }
}
void update(int xa,int ya,int xb,int yb,int z){
  add(xa,ya,z);add(xa,yb+1,-z);add(xb+1,ya,-z);add(xb+1,yb+1,z);
}
int query(int x, int y){
  int res = 0, tmp = y;
  while(x){
    y = tmp;
    while(y){
      res += delta[x][y], y -= lowbit(y);
    }
    x -= lowbit(x);
  }
  return res;
}
void init(){
	for(int i = 1; i <= n; ++i){
      for(int j = 1; j <= n; ++j){
        int tmp = ar[i][j]-ar[i-1][j]-ar[i][j-1]+ar[i-1][j-1];
        add(i,j,tmp);
      }
    }
}

改段求段:

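$$sum[x][y] = (x+1)(y+1)\sum_{i=1}^{x}\sum_{j=1}^{y} d[i][j] - (y+1)\sum_{i=1}^{x}\sum_{j=1}^{y} i\cdot d[i][j] - (x+1)\sum_{i=1}^{x}\sum_{j=1}^{y} j\cdot d[i][j] + \sum_{i=1}^{x}\sum_{j=1}^{y} i\cdot j\cdot d[i][j]$$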

//sum[x][y] = (x+1)(y+1)sigma(d[i][j])-(y+1)sigma(i*d[i][j])-(x+1)sigma(j*d[i][j])+sigma(i*j*d[i][j])

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<assert.h>
#include<bitset>
#define lson rt<<1
#define rson rt<<1|1
#define lowbit(x) ((x)&(-(x)))  // outer parentheses: '&' binds looser than '-', '+', '*', '>=', '=='
#define all(x) (x).begin(),(x).end()
using namespace std;
typedef long long LL;
const int INF = 0x3f3f3f3f;
const int N = (int)1e3 +107;
int ar[N][N], da[N][N], di[N][N], dj[N][N],dij[N][N];
int n, m, q;
//sumxy = (x+1)(y+1)sigma(d[i][j])-(y+1)sigma(i*d[i][j])-(x+1)sigma(j*d[i][j])+sigma(i*j*d[i][j])
void add(int x, int y, int z){
  for(int i=x;i<=n;i+=lowbit(i)){
    for(int j=y;j<=n;j+=lowbit(j)){
      da[i][j] += z; di[i][j] += z*x; dj[i][j] += z*y; dij[i][j] += z*x*y;
    }
  }
}
void update(int xa,int ya,int xb,int yb,int z){
  add(xa,ya,z);add(xa,yb+1,-z);add(xb+1,ya,-z);add(xb+1,yb+1,z);
}
int query(int x, int y){
  int res = 0;
  for(int i = x; i>0; i -= lowbit(i)){
    for(int j = y; j>0; j -= lowbit(j)){
      res += (x+1)*(y+1)*da[i][j] - (y+1)*di[i][j] - (x+1)*dj[i][j] + dij[i][j];
    }
  }
  return res;
}
int ask(int xa,int ya,int xb,int yb){
  return query(xb,yb)-query(xb,ya-1)-query(xa-1,yb)+query(xa-1,ya-1);
}
void init(){
  for(int i = 1; i <= n; ++i){
    for(int j = 1; j <= n; ++j){
      int tmp = ar[i][j]-ar[i-1][j]-ar[i][j-1]+ar[i-1][j-1];
      add(i,j,tmp);
      //update(i,j,i,j,ar[i][j]);
    }
  }
}
int main(){
  while(~scanf("%d", &n)){
    memset(ar,0,sizeof(ar));
    for(int i=1;i<=n;++i){
      for(int j=1;j<=n;++j){
        scanf("%d",&ar[i][j]);
      }
    }
    init();
    scanf("%d",&q);
    while(q--){
      int op,xa,xb,ya,yb,c;
      scanf("%d%d%d%d%d",&op,&xa,&ya,&xb,&yb);
      if(op==1){
        scanf("%d",&c);
        update(xa,ya,xb,yb,c);
      }else{
        printf("%d
", ask(xa,ya,xb,yb));
      }
      if(q<=0)break;
    }
  }
  return 0;
}

习题答案:

HDU1892

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<assert.h>
#include<bitset>
#define lson rt<<1
#define rson rt<<1|1
#define lowbit(x) ((x)&(-(x)))  // outer parentheses: '&' binds looser than '-', '+', '*', '>=', '=='
#define all(x) (x).begin(),(x).end()
using namespace std;
typedef long long LL;
const int INF = 0x3f3f3f3f;
const int N = (int)1e3 +107;
int ar[N][N], da[N][N], di[N][N], dj[N][N],dij[N][N];
int n, m, q;
//sumxy = (x+1)(y+1)sigma(d[i][j])-(y+1)sigma(i*d[i][j])-(x+1)sigma(j*d[i][j])+sigma(i*j*d[i][j])
// Point update: cell (x, y) += c on the n x n tree da.
void add(int x,int y,int c){
  for (int r = x; r <= n; r += lowbit(r))
    for (int s = y; s <= n; s += lowbit(s))
      da[r][s] += c;
}
// Prefix sum over the rectangle (1, 1)-(x, y).
int query(int x,int y){
  int acc = 0;
  for (int r = x; r; r -= lowbit(r))
    for (int s = y; s; s -= lowbit(s))
      acc += da[r][s];
  return acc;
}
// Sum over (xa, ya)-(xb, yb) by inclusion-exclusion.
int ask(int xa,int ya,int xb,int yb){
  int res = query(xb, yb);
  res -= query(xa - 1, yb);
  res -= query(xb, ya - 1);
  res += query(xa - 1, ya - 1);
  return res;
}
int main(){
  int tim;
  int tc=0;
  scanf("%d",&tim);
  while(tim--){
    n=1002;
    memset(da,0,sizeof(da));
    for(int i=1;i<=n;++i){
      for(int j=1;j<=n;++j){
        add(i,j,1);
      }
    }
    printf("Case %d:
", ++tc);
    scanf("%d",&q);
    while(q--){
      char op[2];
      int xa=0,xb=0,ya=0,yb=0,c;
      scanf("%s",op);
      if(op[0]=='S'){
        scanf("%d%d%d%d",&xa,&ya,&xb,&yb);
        xa++;ya++;xb++;yb++;
        if(xa>xb)swap(xa,xb);
        if(ya>yb)swap(ya,yb);
        printf("%d
", ask(xa,ya,xb,yb));
      }else if(op[0]=='A'){
        scanf("%d%d%d",&xa,&ya,&c);
        xa++;ya++;xb++;yb++;
        add(xa,ya,c);
      }else if(op[0]=='D'){
        scanf("%d%d%d",&xa,&ya,&c);
        xa++;ya++;xb++;yb++;
        c=min(c,ask(xa,ya,xa,ya));
        add(xa,ya,-c);
      }else{
        scanf("%d%d%d%d%d",&xa,&ya,&xb,&yb,&c);
        xa++;ya++;xb++;yb++;
        c=min(c,ask(xa,ya,xa,ya));
        add(xa,ya,-c);
        add(xb,yb,c);
      }
    }
  }
  return 0;
}

HDU2642:

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<assert.h>
#include<bitset>
#define lson rt<<1
#define rson rt<<1|1
#define lowbit(x) ((x)&(-(x)))  // outer parentheses: '&' binds looser than '-', '+', '*', '>=', '=='
#define all(x) (x).begin(),(x).end()
using namespace std;
typedef long long LL;
const int INF = 0x3f3f3f3f;
const int N = (int)1e3 +107;
int ar[N][N], da[N][N], di[N][N], dj[N][N],dij[N][N];
int n, m, q;
//sumxy = (x+1)(y+1)sigma(d[i][j])-(y+1)sigma(i*d[i][j])-(x+1)sigma(j*d[i][j])+sigma(i*j*d[i][j])
// Point update: cell (x, y) += z on the n x n tree da.
void add(int x, int y, int z){
  for (int r = x; r <= n; r += lowbit(r))
    for (int s = y; s <= n; s += lowbit(s))
      da[r][s] += z;
}
// Prefix sum over the rectangle (1, 1)-(x, y).
int query(int x, int y){
  int acc = 0;
  for (int r = x; r > 0; r -= lowbit(r))
    for (int s = y; s > 0; s -= lowbit(s))
      acc += da[r][s];
  return acc;
}
// Sum over (xa, ya)-(xb, yb) by inclusion-exclusion.
int ask(int xa,int ya,int xb,int yb){
  int res = query(xb, yb);
  res -= query(xb, ya - 1);
  res -= query(xa - 1, yb);
  res += query(xa - 1, ya - 1);
  return res;
}
int main(){
  while(~scanf("%d", &q)){
    n=1001;
    memset(da,0,sizeof(da));
    memset(ar,0,sizeof(ar));
    while(q--){
      char op[2];
      int xa,xb,ya,yb,c;
      scanf("%s",op);
      if(op[0]=='B'){
        scanf("%d%d",&xa,&ya);
        xa++;ya++;
        if(ar[xa][ya]==0)add(xa,ya,1);
        ar[xa][ya]=1;
      }else if(op[0]=='D'){
        scanf("%d%d",&xa,&ya);
        xa++;ya++;
        if(ar[xa][ya]==1)add(xa,ya,-1);
        ar[xa][ya]=0;
      }else{
        scanf("%d%d",&xa,&xb);
        scanf("%d%d",&ya,&yb);
        xa++;ya++;
        xb++;yb++;
        if(xa>xb)swap(xa,xb);
        if(ya>yb)swap(ya,yb);
        printf("%d
", ask(xa,ya,xb,yb));
      }
    }
  }
  return 0;
}

HDU1541:

#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
using namespace std;
const int N=1e5+10;
int a[N],sum[N];
int lowbit(int x){ return x & -x; }  // value of the lowest set bit of x
// Prefix count: a[1] + ... + a[n], i.e. how many inserted points have coordinate <= n.
int Sum(int n){
    int total = 0;
    for (; n > 0; n -= lowbit(n))
        total += a[n];
    return total;
}
// Inserts one point at 1-based coordinate x (increments every covering node).
// a[] has N slots (indices 0..N-1), so the walk must stop before N. The previous
// `x <= N` bound allowed touching a[N], one past the end.
void add(int x){
    for (; x < N; x += lowbit(x))
        ++a[x];
}
// HDU1541 driver. For each point, its level is the number of earlier points with x' <= x.
// Prints how many points have each level 0..n-1.
// NOTE(review): assumes points arrive sorted by y then x (as the problem states), so
// every earlier point also has y' <= y.
int main() {
    int x,y,n;
    while(~scanf("%d",&n)){
        memset(a,0,sizeof(a));
        memset(sum,0,sizeof(sum));
        for(int i=0;i<n;++i){
            scanf("%d %d",&x,&y);
            // Sample trace — raw x: 1 5 7 3 5 / shifted x: 2 6 8 4 6 / level: 0 1 2 1 3
            x++;  // coordinates may be 0; the tree is 1-based
            sum[Sum(x)]++;
            add(x);
        }
        for(int i=0;i<n;++i){
            printf("%d\n",sum[i]);
        }
    }
    return 0;
}

具体请访问这个博客：here（原文此处为超链接，链接地址在转载时已丢失）

树状数组求区间最值

$bit[x]$ 是区间 $[x-\mathrm{lowbit}(x)+1,\ x]$ 的最值
能转移到 $x$ 的状态是 $x-2^0, x-2^1, \dots, x-2^k$（其中 $2^k < \mathrm{lowbit}(x)$）
若 $y-\mathrm{lowbit}(y) \ge x$，则 $query(x,y) = \max(bit[y],\ query(x,\ y-\mathrm{lowbit}(y)))$
若 $y-\mathrm{lowbit}(y) < x$，则 $query(x,y) = \max(ar[y],\ query(x,\ y-1))$

#include<iostream>
#include<cstdio>
#include<assert.h>
#include<ctime>
#include<algorithm>
#include<cstring>
//#include<bits/stdc++.h>
#define lowbit(x) ((x)&(-(x)))  // outer parentheses: without them `y - lowbit(y) >= x` parses as `(y-y) & (-y >= x)`
#define all(x) x.begin(),x.end()
#define iis std::ios::sync_with_stdio(false)
#define mme(a,b) memset((a),(b),sizeof((a)))
using namespace std;
typedef long long LL;
const int MXN = 2e5+7;
const int INF = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
int n, m;
int ar[MXN], bit[MXN];

// Value of the lowest set bit of x. The name is parenthesised so that the function-like
// macro `lowbit` defined above does not expand over this definition; unparenthesised,
// it turns into `int (int x)&(-(int x))`, which does not compile.
int (lowbit)(int x){return x & (-x);}
inline void mymax(int &a,int b){ if (a <= b) a = b; }  // a = max(a, b), in place
// Refreshes bit[] after ar[x] changed: rebuilds bit[x] and then each ancestor, using
// bit[p] = max(ar[p], bit[p-1], bit[p-2], bit[p-4], ..., bit[p-2^k]) with 2^k < lowbit(p).
// Because each node is rebuilt from scratch, decreasing a value is handled too.
void update(int x){
  for (; x <= n; x += lowbit(x)) {
    bit[x] = ar[x];
    const int span = lowbit(x);
    for (int step = 1; step < span; step *= 2)
      mymax(bit[x], bit[x - step]);
  }
}
// Maximum of ar[x..y] (1-based, inclusive); returns 0 when x > y.
// NOTE(review): the 0 start value assumes non-negative entries.
// Walks y downward. When the whole node range [y-lowbit(y)+1, y] lies inside [x, y]
// (checked conservatively as y - lowbit(y) >= x), take bit[y] and jump. Otherwise take
// the single element ar[y] and step by one.
int query(int x,int y){
  int ans = 0;
  while(y >= x){
    mymax(ans, ar[y]);
    --y;
    // lowbit(y) is parenthesised explicitly. With the file's macro `(x)&(-(x))`,
    // `y - lowbit(y) >= x` parsed as `(y - y) & (-y >= x)`, which is always 0. The jumps
    // never ran, and every query degraded to an O(n) scan.
    for( ; y - (lowbit(y)) >= x; y -= (lowbit(y))){
      mymax(ans, bit[y]);
    }
  }
  return ans;
}
// Range-max driver. Reads n values, then m operations:
//   "U a b"  set ar[a] = b
//   other    print max(ar[a..b])
int main(){
  char op;
  int a, b;
  while(scanf("%d%d", &n, &m)!=EOF){
    // `register` dropped: C++17 no longer allows that storage-class specifier.
    for(int i = 0; i <= n; ++i)bit[i] = 0;
    for(int i = 1; i <= n; ++i){
      scanf("%d", &ar[i]);
      update(i);
    }
    while(m--){
      // " %c" skips all leading whitespace (newline, '\r', spaces). The previous two
      // "%c" reads assumed exactly one separator character before the op letter.
      scanf(" %c", &op);
      scanf("%d%d", &a, &b);
      if(op == 'U') {
        ar[a] = b;
        update(a);
      }else {
        a = query(a, b);
        printf("%d\n", a);
      }
    }
  }
  return 0;
}

原文地址:https://www.cnblogs.com/Cwolf9/p/9513252.html