洛谷P3328(bzoj 4085)毒瘤线段树

题面及大致思路:https://www.cnblogs.com/Yangrui-Blog/p/9623294.html, https://www.cnblogs.com/New-Godess/p/4567282.html

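每个点维护 2 个矩阵,一共 15 个变量(线段树节点中存的是区间内各位置上这些量之和)。设某位置左、右相邻位置的值分别为 $x$、$y$,数列 $F$ 满足 $F(k+1)=F(k)+a\,F(k-1)+b$。第一个矩阵是 $2\times3$ 的(对应代码中的 `sum`):$\begin{bmatrix}F(x-1)&F(x)&F(x+1)\\F(y-1)&F(y)&F(y+1)\end{bmatrix}$;第二个矩阵是 $3\times3$ 的(对应代码中的 `val`),由第一行与第二行的元素两两相乘得到,即 $\mathrm{val}_{i,j}=F(x+i-1)\,F(y+j-1)$,$i,j\in\{0,1,2\}$。矩阵转移的过程很显然,就不细说了。这个题的思维难度不高,但有两点很烦人:一是卡常,二是要维护的变量很多,写起来很麻烦。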

代码:

#include <bits/stdc++.h>
// Child indices in the 1-based implicit segment tree (root at 1).
// The macro argument is parenthesised so that an expression argument such as
// ls(i | 1) expands to ((i | 1) << 1) instead of (i | (1 << 1)).
#define ls(x) ((x) << 1)
#define rs(x) (((x) << 1) | 1)
using namespace std;
const int mod = 1000000007;  // modulus for all arithmetic (inverses via Fermat)
const int maxn = 300010;     // capacity for sequence positions (presumably n + slack)
// c[i][0]: i-th input value x; c[i][1], c[i][2], c[i][3]: F(x-1), F(x), F(x+1)
// where F(k+1) = F(k) + a*F(k-1) + b (mod `mod`).
int c[maxn][4];
// a, b: recurrence coefficients read from input; inv: a^(mod-2), i.e. the
// modular inverse of a when a != 0 (dec1 special-cases a == 0).
int a, b, inv;
// Modular exponentiation by repeated squaring: returns x^y mod `mod`.
// Expected to be called with y >= 0 (here: y = mod - 2 for Fermat inverses).
int qpow(int x, int y) {
    long long result = 1;
    long long base = x;
    while (y != 0) {
        if (y & 1) result = result * base % mod;
        base = base * base % mod;
        y >>= 1;
    }
    return static_cast<int>(result);
}
// Fixed-capacity (at most 3x3) matrix over Z_mod; n = rows, m = columns.
struct Matrix {
    int a[3][3], n, m;

    // Turn this into the x-by-x identity (x <= 3); every other cell is zero.
    void init(int x) {
        memset(a, 0, sizeof(a));
        n = x;
        m = x;
        for (int d = 0; d < n; ++d) a[d][d] = 1;
    }

    // Specialised product for the affine transition matrices in this file.
    // Only result columns 0 and 1 are computed; column 2 is left zero except
    // that a[2][2] is forced to 1. This equals the true product for 3x3
    // operands whose third column is (0, 0, 1)^T (as for E, B and B's powers).
    // For a 1x3 left operand only columns 0 and 1 of the result are meaningful.
    Matrix operator * (const Matrix &rhs) const {
        Matrix out;
        memset(out.a, 0, sizeof(out.a));
        out.n = n;
        out.m = rhs.m;
        for (int row = 0; row < n; ++row) {
            for (int col = 0; col < 2; ++col) {
                int acc = 0;
                for (int k = 0; k < m; ++k)
                    acc = (acc + 1ll * a[row][k] * rhs.a[k][col] % mod) % mod;
                out.a[row][col] = acc;
            }
        }
        out.a[2][2] = 1;
        return out;
    }
};
Matrix A;      // 1x3 starting row (2, 1, 1), set in init_p
Matrix B;      // 3x3 one-step transition; init_p squares it in place 32 times
Matrix p[35];  // p[i] = (original B)^(2^(i-1)) for i = 1..32, filled by init_p
Matrix E;      // 3x3 identity, set via E.init(3)
// Returns the k-th power of the transition matrix stored in p[] by init_p,
// multiplying in p[bit] = B^(2^(bit-1)) for every set bit of k.
// Precondition: k >= 0. With a negative k, an arithmetic right shift never
// reaches 0, so the loop would run past the end of p[].
Matrix qpow(int k) {
    Matrix result = E;
    int bit = 1;
    while (k) {
        if (k & 1) result = result * p[bit];
        ++bit;
        k >>= 1;
    }
    return result;
}
void get_Matrix(int pos) {
    Matrix tmp = A * qpow(c[pos][0] - 2);
    c[pos][1] = tmp.a[0][1], c[pos][2] = tmp.a[0][0];
}
// Sets up the start row A and the transition matrix B, where
//   (F(k), F(k-1), 1) * B = (F(k+1), F(k), 1),  F(k+1) = F(k) + a*F(k-1) + b,
// then fills p[i] = B^(2^(i-1)) for i = 1..32 (B itself ends up squared).
// Cells not assigned here stay 0 because A and B are zero-initialised globals.
void init_p() {
    A.n = 1;
    A.m = 3;
    A.a[0][0] = 2;
    A.a[0][1] = 1;
    A.a[0][2] = 1;

    B.n = 3;
    B.m = 3;
    B.a[0][0] = 1;  // column 0: F(k+1) = F(k) + a*F(k-1) + b*1
    B.a[1][0] = a;
    B.a[2][0] = b;
    B.a[0][1] = 1;  // column 1: carries F(k) forward
    B.a[2][2] = 1;  // column 2: keeps the constant 1

    for (int level = 1; level <= 32; ++level) {
        p[level] = B;
        B = B * B;
    }
}
// Segment-tree node over positions [l, r]. Position i pairs the value x of its
// left neighbour with the value y of its right neighbour; aggregates are
// modular sums over the covered positions.
struct SegementTree {
    // sum[0][k]: sum of F(x + k - 1); sum[1][k]: sum of F(y + k - 1); k = 0..2
    int sum[2][3];
    // val[i][j]: sum of F(x + i - 1) * F(y + j - 1)
    int val[3][3];
    // covered position range
    int l, r;
    // net shift of x (lz[0]) / y (lz[1]) already applied here, not yet to children
    int lz[2];
};
// Implicit binary-tree storage: root at index 1, children at ls(i) / rs(i).
SegementTree tr[4 * maxn];
inline void pushup(int now) {
    for (int i = 0; i < 3; i++)
        for (int j = 0; j < 3; j++)
            tr[now].val[i][j] = (tr[ls(now)].val[i][j] + tr[rs(now)].val[i][j]) % mod;
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 3; j++)
            tr[now].sum[i][j] = (tr[ls(now)].sum[i][j] + tr[rs(now)].sum[i][j]) % mod;
}
inline void add1(int now, int flag) {
    int l = tr[now].l, r = tr[now].r;
    for (int i = 0; i < 2; i++)
        tr[now].sum[flag][i] = tr[now].sum[flag][i + 1];
    tr[now].sum[flag][2] = (tr[now].sum[flag][1] + 1ll * tr[now].sum[flag][0] * a + 1ll * b * (r - l + 1)) % mod;
    if(flag == 0) {
        for (int i = 0; i < 2; i++)
            for (int j = 0 ; j < 3; j++)
             tr[now].val[i][j] = tr[now].val[i + 1][j];
        for (int j = 0; j < 3; j++)
            tr[now].val[2][j] = (tr[now].val[1][j] + 1ll * tr[now].val[0][j] * a + 1ll * b * tr[now].sum[1][j]) % mod;
    } else {
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 2; j++)
                tr[now].val[i][j] = tr[now].val[i][j + 1];
        for (int i = 0; i < 3; i++)
            tr[now].val[i][2] = (tr[now].val[i][1] + 1ll * tr[now].val[i][0] * a + 1ll * b * tr[now].sum[0][i]) % mod;
    }
}
// Inverse of add1: shifts one side of every position in node `now` by -1.
// flag == 0 affects the left-neighbour values x, flag == 1 the right ones y.
// The window (F(v-1), F(v), F(v+1)) rolls to (F(v-2), F(v-1), F(v)), where the
// new first entry comes from the backward recurrence
//   F(v-2) = (F(v) - F(v-1) - b) / a,
// computed by multiplying with inv (expected to hold a^{-1} mod `mod`).
// When a == 0 the recurrence degenerates to F(k+1) = F(k) + b, so
// F(v-2) = F(v-1) - b and no division is needed.
inline void dec1(int now, int flag) {
    int l = tr[now].l, r = tr[now].r;
    if(a == 0) {
        // Roll the window right: (s0, s1, s2) -> (?, s0, s1).
        for (int i = 1; i >= 0; i--) tr[now].sum[flag][i + 1] = tr[now].sum[flag][i];
        // Each of the r - l + 1 covered positions subtracts one b; "+ mod" keeps it non-negative.
        tr[now].sum[flag][0] = (tr[now].sum[flag][1] - 1ll * b * (r - l + 1) % mod + mod) % mod;
        if(flag == 0) {
            // Roll the rows of val down; rebuild row 0 against the right-side sums.
            for (int i = 1; i >= 0; i--)
                for (int j = 0; j < 3; j++)
                    tr[now].val[i + 1][j] = tr[now].val[i][j];
            for (int i = 0; i < 3; i++)
                tr[now].val[0][i] = (tr[now].val[1][i] - 1ll * b * tr[now].sum[1][i] % mod + mod) % mod;
        } else {
            // Roll the columns of val right; rebuild column 0 against the left-side sums.
            for (int i = 0; i < 3; i++)
                for (int j = 1; j >= 0; j--)
                    tr[now].val[i][j + 1] = tr[now].val[i][j];
            for (int i = 0; i < 3; i++)
                tr[now].val[i][0] = (tr[now].val[i][1] - 1ll * b * tr[now].sum[0][i] % mod + mod) % mod;
        }
        return;
    }
    // General case (a != 0): roll the window, then divide by a via inv.
    for (int i = 1; i >= 0; i--)
        tr[now].sum[flag][i + 1] = tr[now].sum[flag][i];
    // The difference may be negative; the outer "+ mod) % mod" normalises the remainder.
    tr[now].sum[flag][0] = ((tr[now].sum[flag][2] - tr[now].sum[flag][1] - 1ll * b * (r - l + 1) % mod) * inv % mod + mod) % mod;
    if(flag == 0) {
        // Roll rows down; row 0 = (row 2 - row 1 - b * right-side sums) / a.
        for (int i = 1; i >= 0; i--)
            for (int j = 0; j < 3; j++)
                tr[now].val[i + 1][j] = tr[now].val[i][j];
        for (int i = 0; i < 3; i++)
            tr[now].val[0][i] = ((tr[now].val[2][i] - tr[now].val[1][i] - 1ll * b * tr[now].sum[1][i] % mod) % mod * inv % mod + mod) % mod;
    } else {
        // Roll columns right; column 0 = (col 2 - col 1 - b * left-side sums) / a.
        for (int i = 0; i < 3; i++)
            for (int j = 1; j >= 0; j--)
                tr[now].val[i][j + 1] = tr[now].val[i][j];
        for (int i = 0; i < 3; i++)
            tr[now].val[i][0] = ((tr[now].val[i][2] - tr[now].val[i][1] - 1ll * b * tr[now].sum[0][i] % mod) % mod * inv % mod + mod) % mod;
    }
}
// Applies a net shift of y to side `flag` of node `now`: add1 y times when
// y > 0, dec1 -y times when y < 0, nothing when y == 0.
// The cost grows linearly with |y|.
inline void pushdown(int now, int flag, int y) {
    if (y > 0) {
        for (int step = 0; step < y; ++step) add1(now, flag);
        return;
    }
    for (int step = 0; step < -y; ++step) dec1(now, flag);
}
// Hands node `now`'s pending shifts (both sides) to its two children: each
// child is shifted immediately and records the shift in its own lz, then the
// parent's lz is cleared.
inline void Pushdown(int now) {
    const int kids[2] = {ls(now), rs(now)};
    for (int flag = 0; flag < 2; ++flag) {
        const int pending = tr[now].lz[flag];
        for (int k = 0; k < 2; ++k) pushdown(kids[k], flag, pending);
        for (int k = 0; k < 2; ++k) tr[kids[k]].lz[flag] += pending;
        tr[now].lz[flag] = 0;
    }
}
// Builds the subtree rooted at `now` over positions [l, r]. From main the
// root covers 2..n-1, i.e. the positions that have both neighbours.
// A leaf at position l takes its left window from c[l-1] and its right window
// from c[l+1], then stores all nine pairwise products in val. Internal nodes
// are combined with pushup.
// An empty range (l > r, which is what the root gets when n < 3) produces an
// empty node. Previously it recursed forever, because mid = (l + r) >> 1
// never splits such a range.
inline void build(int now, int l, int r) {
    tr[now].l = l, tr[now].r = r;
    if (l > r) return;  // empty node: sums/products stay 0, update/query skip it
    if (l == r) {
        for (int i = 0; i < 3; i++) {
            tr[now].sum[0][i] = c[l - 1][i + 1];  // F(x-1), F(x), F(x+1) of left neighbour
            tr[now].sum[1][i] = c[l + 1][i + 1];  // F(y-1), F(y), F(y+1) of right neighbour
        }
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 3; j++)
                tr[now].val[i][j] = 1ll * tr[now].sum[0][i] * tr[now].sum[1][j] % mod;
        return;
    }
    int mid = (l + r) >> 1;
    build(ls(now), l, mid);
    build(rs(now), mid + 1, r);
    pushup(now);
}
inline void update(int now, int ql, int qr, int flag,int val) {
    int l = tr[now].l, r = tr[now].r;
    if(l > qr || r < ql) return;
    if(l >= ql && r <= qr) {
        tr[now].lz[flag] += val;
        pushdown(now, flag, val);
        return;
    }
    Pushdown(now);
    int mid = (l + r) >> 1;
    if(ql <= mid)	update(ls(now), ql, qr, flag, val);
    if(qr > mid)	update(rs(now), ql, qr, flag, val);
    pushup(now);
}
inline int query(int now, int ql, int qr) {
    int l = tr[now].l, r = tr[now].r;
    if(l > qr || r < ql) return 0;
    if(l >= ql && r <= qr) return tr[now].val[2][0];
    Pushdown(now);
    int mid = (l + r) >> 1;
    int ans = 0;
    if(ql <= mid) ans = (ans + query(ls(now), ql, qr)) % mod;
    if(qr > mid) ans = (ans + query(rs(now), ql, qr)) % mod;
    return ans;
}
int main() {
    int n, m;
    scanf("%d%d%d%d", &n, &m, &a, &b);
    inv = qpow(a, mod - 2);
    E.init(3);
    init_p();
    for (int i = 1; i <= n; i++) {
        scanf("%d", &c[i][0]);
        get_Matrix(i);
        c[i][3] = (c[i][2] + 1ll * a * c[i][1] % mod + b) % mod;
    }
    char op[10];
    build(1, 2, n - 1);
    while(m--) {
        scanf("%s", op + 1);
        int x, y;
        scanf("%d%d", &x, &y);
        if(op[1] == 'p') {
            update(1, x + 1, y + 1, 0, 1);
            update(1, x - 1, y - 1, 1, 1);
        } else if (op[1] == 'm') {
            update(1, x + 1, y + 1, 0, -1);
            update(1, x - 1, y - 1, 1, -1);
        } else {
            printf("%d
", query(1, x + 1, y - 1));
        }
    }
}

  

原文地址:https://www.cnblogs.com/pkgunboat/p/10629869.html