Fast Matrix Operations

A Simple Problem with Integers

思路:区间加有两种实现方式:(1)插入和查询时,每经过一个节点就把该节点上的增量(覆盖数目)向下传递给孩子节点(insert/query);(2)沿用之前的方法,不向下传递,查询时累计从根节点到当前节点路径上各节点的覆盖数目(insert2/query2)。

#include <cstdio>
#include <iostream>
using namespace std;

const int MAXN = 100005;

typedef long long int64;

// Initial array values, 1-indexed (d[1..n]); read by SegTree::build.
int d[MAXN];

// One node of the range-add / range-sum segment tree, covering [L, R].
//   c   : pending increment (lazy tag) that applies to every element of [L, R].
//   sum : sum of the elements of [L, R] NOT counting this node's own tag c;
//         the true interval sum is sum + get_c().
class SegNode {
public:
    int L, R;
    int64 c, sum;
    // Total contribution of this node's tag over its whole interval.
    int64 get_c() { return c * (R - L + 1); }
    // Debug helper: prints the node's interval, tag and stored sum.
    void log(const char *info) {
        printf("%s: [%d %d]: %lld, %lld.\n", info, L, R, c, sum);
    }
} node[MAXN * 4];  // root at index 1, children of r at 2r and 2r+1

// Segment tree over the global node[] array (root index 1) supporting
// "add c to every element of [L, R]" and "sum of [L, R]".
//
// Two strategies are provided. Both keep the invariant
//   node[r].sum == (true sum of r's interval) - node[r].get_c()
// so they may be mixed:
//   insert / query   : whenever a node is split, its tag is pushed to both children.
//   insert2 / query2 : tags are never pushed; query2 accumulates the tags met on
//                      the root-to-node path in dd instead.
class SegTree {
public:
    // Debug helper: dumps node 3 (the right child of the root).
    void log(const char *info) {
        printf("%s:\n", info);
        printf("{%d %d}, %lld, %lld.\n", node[3].L, node[3].R, node[3].c, node[3].sum);
    }

    // Builds the subtree rooted at r over [L, R] from d[L..R], clearing all tags.
    void build(int r, int L, int R) {
        node[r].L = L;
        node[r].R = R;
        node[r].c = 0;
        if (L == R) {
            node[r].sum = d[L];
        } else {
            int M = (L + R) / 2;
            build(2 * r, L, M);
            build(2 * r + 1, M + 1, R);
            pull_up(r);
        }
    }

    // Returns the sum over [L, R] intersected with r's interval (0 if disjoint),
    // pushing tags down along the visited path.
    int64 query(int r, int L, int R) {
        if (disjoint(r, L, R)) return 0;
        if (covered(r, L, R)) return node[r].sum + node[r].get_c();
        push_down(r);
        int64 res = query(2 * r, L, R) + query(2 * r + 1, L, R);
        pull_up(r);
        return res;
    }

    // Adds c to every element of [L, R] intersected with r's interval,
    // pushing tags down along the visited path.
    void insert(int r, int L, int R, int c) {
        if (disjoint(r, L, R)) return;
        if (covered(r, L, R)) {
            node[r].c += c;
            return;
        }
        push_down(r);
        insert(2 * r, L, R, c);
        insert(2 * r + 1, L, R, c);
        pull_up(r);
    }

    // Same effect as insert, but tags are left where they are (never pushed);
    // recurses into insert2 so the no-push strategy is applied consistently.
    void insert2(int r, int L, int R, int c) {
        if (disjoint(r, L, R)) return;
        if (covered(r, L, R)) {
            node[r].c += c;
            return;
        }
        insert2(2 * r, L, R, c);
        insert2(2 * r + 1, L, R, c);
        pull_up(r);
    }

    // Sum over [L, R] intersected with r's interval without modifying the tree.
    // dd is the sum of the tags of r's ancestors (pass 0 at the root); it is
    // 64-bit so accumulated tags are not truncated.
    int64 query2(int r, int L, int R, int64 dd) {
        if (disjoint(r, L, R)) return 0;
        dd += node[r].c;
        if (covered(r, L, R)) {
            return node[r].sum + static_cast<int64>(node[r].R - node[r].L + 1) * dd;
        }
        // Recurse with query2 (not query) so the ancestors' tags in dd are kept.
        return query2(2 * r, L, R, dd) + query2(2 * r + 1, L, R, dd);
    }

private:
    // True when [L, R] does not intersect r's interval.
    bool disjoint(int r, int L, int R) {
        return R < node[r].L || node[r].R < L;
    }
    // True when r's interval lies entirely inside [L, R].
    bool covered(int r, int L, int R) {
        return L <= node[r].L && node[r].R <= R;
    }
    // Moves r's tag onto both children.
    void push_down(int r) {
        node[2 * r].c += node[r].c;
        node[2 * r + 1].c += node[r].c;
        node[r].c = 0;
    }
    // Recomputes r's stored sum (excluding r's own tag) from its children.
    void pull_up(int r) {
        node[r].sum = node[2 * r].sum + node[2 * r + 1].sum
                    + node[2 * r].get_c() + node[2 * r + 1].get_c();
    }
};

// Reads test cases until EOF. Each case: "n q", then d[1..n], then q commands:
//   "C a b c" -> add c to every element of [a, b]
//   "Q a b"   -> print the sum of [a, b]
int main() {
    int n, q;
    // Stop on EOF or malformed input instead of looping on a failed read.
    while (scanf("%d%d", &n, &q) == 2) {
        SegTree tree;
        for (int i = 1; i <= n; i++) scanf("%d", &d[i]);
        tree.build(1, 1, n);
        while (q--) {
            char ch[2];
            int a, b;
            // %1s reads exactly one command letter, so ch[2] cannot overflow.
            if (scanf("%1s%d%d", ch, &a, &b) != 3) break;
            if (ch[0] == 'C') {
                int c;
                if (scanf("%d", &c) != 1) break;
                tree.insert(1, a, b, c);
                // Alternative (no push-down) strategy: tree.insert2(1, a, b, c);
            } else if (ch[0] == 'Q') {
                printf("%lld\n", tree.query(1, a, b));
                // Alternative: printf("%lld\n", tree.query2(1, a, b, 0));
            }
        }
    }
    return 0;
}

Fast Matrix Operations

需要注意的是:

1. 插入及查询在树上向下遍历时,不管某个孩子节点是否会被遍历,都应该先将当前节点上的覆盖数目向下传递给所有孩子;

2. 建树时,某些孩子节点对应的区域为空,因而不会被构建出来(指针为NULL)。对于这类节点,插入/查询向孩子节点遍历时,原本应当判断失败的区间条件可能会判断成功,因此访问孩子前必须先检查其是否为NULL,否则会导致错误;

3. 节点中的m_sum(以及m_min、m_max)表示以当前节点为根的子树中,不计当前节点自身的覆盖数v时所有数的和(最小值、最大值);真实的和为m_sum + v * 面积,真实的最小值为m_min + v。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <limits>
using namespace std;

typedef long long int64;

// One node of a quadtree-style 2D segment tree covering the rectangle
// rows [L, R] x columns [B, T].
//   v        : lazy tag. If is_clear is false, v is added to every cell below;
//              if is_clear is true, every cell below equals v and the children
//              are stale until the tag is pushed down.
//   m_min, m_max, m_sum : min / max / sum of the subtree NOT counting this
//              node's own tag, so the true minimum is m_min + v and the true
//              sum is m_sum + v * area().
//   sons[4]  : the four quadrants; a son is NULL when its quadrant is empty.
class SegNode {
public:
    int L, R, B, T;
    int64 v, m_min, m_max, m_sum;
    bool is_clear;
    SegNode* sons[4];
    SegNode() {
        L = R = B = T = 0;
        v = m_min = m_max = m_sum = 0;
        is_clear = false;
        // Set each pointer explicitly; memset with NULL as the fill value relies
        // on NULL converting to the integer 0.
        for (int i = 0; i < 4; i++) sons[i] = NULL;
    }
    // Number of cells in this node's rectangle, computed in 64 bits.
    int64 area() {
        return static_cast<int64>(R - L + 1) * (T - B + 1);
    }
};

// Quadtree-style 2D segment tree over rows [L, R] x columns [B, T] supporting:
//   insert with k == 1 : add v to every cell of a sub-rectangle;
//   insert with k == 2 : set every cell of a sub-rectangle to v;
//   query              : minimum, maximum and sum of a sub-rectangle.
// Nodes are allocated with new by build() and must be released with free().
class SegTree {
public:
    // Recursively deletes the subtree rooted at node; a NULL node is ignored.
    void free(SegNode *node) {
        if (node == NULL) return;
        for (int i = 0; i < 4; i++) free(node->sons[i]);
        delete node;
    }

    // Allocates the subtree for rows [L, R] x columns [B, T] into node.
    // Quadrants that would be empty (single row or single column) stay NULL.
    void build(SegNode* &node, int L, int R, int B, int T) {
        node = new SegNode();
        node->L = L; node->R = R; node->B = B; node->T = T;
        if (L == R && B == T) return;  // single cell: leaf
        int M1 = (L + R) / 2;
        int M2 = (B + T) / 2;
        // sons[0] = rows [L, M1]    x cols [M2+1, T]
        // sons[1] = rows [M1+1, R]  x cols [M2+1, T]
        // sons[2] = rows [L, M1]    x cols [B, M2]
        // sons[3] = rows [M1+1, R]  x cols [B, M2]
        if (L <= M1 && M2 + 1 <= T) build(node->sons[0], L, M1, M2 + 1, T);
        if (M1 + 1 <= R && M2 + 1 <= T) build(node->sons[1], M1 + 1, R, M2 + 1, T);
        if (L <= M1 && B <= M2) build(node->sons[2], L, M1, B, M2);
        if (M1 + 1 <= R && B <= M2) build(node->sons[3], M1 + 1, R, B, M2);
    }

    // Applies operation k to rows [L, R] x columns [B, T] within node's subtree:
    // k == 1 adds v, k == 2 sets cells to v, other values leave cells unchanged.
    // A NULL node or a non-overlapping rectangle is a no-op.
    void insert(SegNode *node, int L, int R, int B, int T, int v, int k) {
        if (!overlaps(node, L, R, B, T)) return;
        if (inside(node, L, R, B, T)) {
            if (k == 1) {
                node->v += v;
            } else if (k == 2) {
                node->v = v;
                node->m_min = node->m_max = node->m_sum = 0;
                node->is_clear = true;
            }
            return;
        }
        push_to_sons(node);
        for (int i = 0; i < 4; i++) insert(node->sons[i], L, R, B, T, v, k);
        // The tag now lives in the sons; rebuild this node's aggregates.
        node->is_clear = false;
        node->v = 0;
        update(node);
    }

    // Pushes parent r's tag into its child t.
    void down(SegNode *r, SegNode *t) {
        if (r->is_clear) {
            // A "set" tag overrides everything the child had.
            t->is_clear = true;
            t->v = r->v;
            t->m_min = t->m_max = t->m_sum = 0;
        } else {
            t->v += r->v;
        }
    }

    // Recomputes r's aggregates (excluding r's own tag) from its non-NULL sons.
    // Aggregates are left untouched when r has no sons.
    void update(SegNode *r) {
        bool first = true;
        for (int i = 0; i < 4; i++) {
            SegNode *s = r->sons[i];
            if (s == NULL) continue;
            int64 lo = s->m_min + s->v;
            int64 hi = s->m_max + s->v;
            int64 total = s->m_sum + s->v * s->area();
            if (first) {
                first = false;
                r->m_min = lo;
                r->m_max = hi;
                r->m_sum = total;
            } else {
                r->m_min = std::min(r->m_min, lo);
                r->m_max = std::max(r->m_max, hi);
                r->m_sum += total;
            }
        }
    }

    // Folds the cells of rows [L, R] x columns [B, T] within node's subtree
    // into mmin / mmax / msum (callers initialize them). A NULL node or a
    // non-overlapping rectangle contributes nothing.
    void query(SegNode *node, int L, int R, int B, int T, int64& mmin, int64& mmax, int64& msum) {
        if (!overlaps(node, L, R, B, T)) return;
        if (inside(node, L, R, B, T)) {
            mmin = std::min(mmin, node->m_min + node->v);
            mmax = std::max(mmax, node->m_max + node->v);
            msum += node->m_sum + node->v * node->area();
            return;
        }
        push_to_sons(node);
        for (int i = 0; i < 4; i++) query(node->sons[i], L, R, B, T, mmin, mmax, msum);
        node->is_clear = false;
        node->v = 0;
        update(node);
    }

private:
    // True when node exists and its rectangle intersects rows [L, R] x cols [B, T].
    // Checking the son's own rectangle (instead of the parent's midpoints) keeps
    // missing quadrants and out-of-range requests from being visited.
    bool overlaps(SegNode *node, int L, int R, int B, int T) {
        return node != NULL && L <= node->R && node->L <= R && B <= node->T && node->B <= T;
    }
    // True when node's rectangle lies entirely inside rows [L, R] x cols [B, T].
    bool inside(SegNode *node, int L, int R, int B, int T) {
        return L <= node->L && node->R <= R && B <= node->B && node->T <= T;
    }
    // Pushes node's tag into every existing son, visited or not.
    void push_to_sons(SegNode *node) {
        for (int i = 0; i < 4; i++) {
            if (node->sons[i] != NULL) down(node, node->sons[i]);
        }
    }
};

int main() {
    //freopen("fast.in", "r", stdin);

    int r, c, m;
    while (scanf("%d%d%d", &r, &c, &m) != EOF) {
        SegTree tree;
        SegNode *root = NULL;
        tree.build(root, 1, r, 1, c);
        for (int i = 0; i < m; i++) {
            int k, x1, y1, x2, y2, v;
            scanf("%d%d%d%d%d", &k, &x1, &y1, &x2, &y2);
            if (k == 3) {
                int64 mmin = 1e9, mmax = -1e9, msum = 0;
                tree.query(root, x1, x2, y1, y2, mmin, mmax, msum);
                printf("%lld %lld %lld
", msum, mmin, mmax);
            } else {
                scanf("%d", &v); 
                tree.insert(root, x1, x2, y1, y2, v, k);
            }
        }
        tree.free(root);
    }
}
原文地址:https://www.cnblogs.com/litstrong/p/3300076.html