[HDU6155]Subsequence Count 线段树+矩阵优化DP

题目链接
DP方程不难推:
状态设为dp[i][0]和dp[i][1],表示从第一位到这一位以1/0结尾的序列数是多少。
假如当前位是1,那么 (egin{equation} left{ egin{aligned} dp[i][1] &= dp[i-1][0] + dp[i-1][1] + 1 \ dp[i][0] &= dp[i-1][0] \ end{aligned} ight. end{equation})
最后的+1是考虑只有一个1的序列。一开始我觉得这样不能保存已经计算出来的序列,后来发现,当一个序列结尾加上1变成一个新序列,这个序列确实没了,但之前产生这个序列的序列会继续产生这个序列。这里比较难理解,可以手动模拟一下。
假如当前位是0,和1同理。
但如果每次都从l到r做dp,复杂度过高,根据经验,可以用矩阵+线段树来降低这一过程的时间复杂度。

如果当前位是1,那么就表示为(egin{bmatrix} 1 & 0 & 0 \ 1 & 1 & 1 \ 0 & 0 & 1 end{bmatrix})

如果当前位是0,那么就表示为(egin{bmatrix} 1 & 1 & 1 \ 0 & 1 & 0 \ 0 & 0 & 1 end{bmatrix})

我们就有(egin{bmatrix} 1 & 0 & 0 \ 1 & 1 & 1 \ 0 & 0 & 1 end{bmatrix} imes egin{bmatrix} dp[i][0] \ dp[i][1] \ 1 end{bmatrix} = egin{bmatrix} dp[i+1][0] \ dp[i+1][1] \ 1 end{bmatrix})

事实上我这里推错了,应该把dp放在左边,然后转移矩阵放右边。不然线段树处理,就是从r乘到l了。为了纪念这个困扰我几天的错误,我决定在博客里把它写出来...
虽然这样对不起读者但也能让读者在知道思路的情况下自己尝试推导正确的转移矩阵。
然后如何处理1和0的互相转换呢,不难发现两种转移矩阵想要互相变换,只需要左乘一个初等方阵交换1,2列,右乘一个初等方阵交换1,2行,对于([l,r])区间的矩阵全部转换,每个矩阵都左乘右乘,然后通过矩阵乘法的结合律,化简为:

(egin{bmatrix} 0 & 1 & 0 \ 1 & 0 & 0 \ 0 & 0 & 1 end{bmatrix} imes A_l imes A_{l+1} imes ... imes A_r imes egin{bmatrix} 0 & 1 & 0 \ 1 & 0 & 0 \ 0 & 0 & 1 end{bmatrix})

因为矩阵乘法具有结合律,通过带lazyTag的线段树不难实现这样的一个矩阵区间乘法和查询问题。
代码结构很明确,先定义Matrix类,定义乘法,然后设几个矩阵常数,然后是线段树。
非AC代码:

#include<iostream>
#include<cstdio>
#define maxn 100005
using namespace std;
int T, N, Q ,type;
char s[maxn];
const int type1[3][3] = {{1, 0, 0}, {1, 1, 1}, {0, 0, 1}};
const int type0[3][3] = {{1, 1, 1}, {0, 1, 0}, {0, 0, 1}};
const int trans[3][3] = {{0, 1, 0}, {1, 0, 0}, {0, 0, 1}};
struct Matrix{
    int arr[3][3];
    //type2
    Matrix operator * (Matrix a){
        Matrix ans;
        for (int i = 0; i < 3;i++)
            for (int j = 0; j <= 3;j++){
                int cnt = 0;
                for (int k = 0; k < 3;k++)
                    cnt += arr[i][k] * a.arr[k][j];
                ans.arr[i][j] = cnt;
            }
        return ans;
    }
    void emplace(){
        for (int i = 0; i < 3;i++)
            for (int j = 0; j < 3;j++)
                if(i==j)
                    arr[i][j] = 1;
                else
                    arr[i][j] = 0;
    }
    void init(int type){
        if(type==1)
            for (int i = 0; i < 3;i++)
                for (int j = 0; j < 3;j++)
                    arr[i][j] = type1[i][j];
        if(type==0)
            for (int i = 0; i < 3;i++)
                for (int j = 0; j < 3;j++)
                    arr[i][j] = type0[i][j];
        if(type==2)
            for (int i = 0; i < 3;i++)
                for (int j = 0; j < 3;j++)
                    arr[i][j] = trans[i][j];
    }
}Trans;
struct node{
    int l, r;
    Matrix num,tag;
} tree[maxn << 2];
void push_up(int rt){
    tree[rt].num = tree[rt << 1].num * tree[rt << 1 | 1].num;
}
void push_down(int rt){
    tree[rt << 1].tag = tree[rt << 1 | 1].tag = tree[rt].tag;
    tree[rt << 1].num = tree[rt << 1].tag * tree[rt << 1].num * tree[rt << 1].tag;
    tree[rt << 1 | 1].num = tree[rt << 1 | 1].tag * tree[rt << 1 | 1].num * tree[rt << 1 | 1].tag;
    tree[rt].tag.emplace();
}
void build(int rt,int l,int r){
    tree[rt].l = l;
    tree[rt].r = r;
    if(l==r){
        if(s[l]=='1')
            tree[rt].num.init(1);
        if(s[l]=='0')
            tree[rt].num.init(0);
        return;
    }
    int mid = l + r >> 1;
    build(rt << 1, l, mid);
    build(rt << 1 | 1, mid + 1, r);
    push_up(rt);
}
void update(int rt,int l,int r){
    if(l<=tree[rt].l&&tree[rt].r<=r){
        tree[rt].num = Trans * tree[rt].num * Trans;
        tree[rt].tag = tree[rt].tag * Trans;
        return;
    }
    push_down(rt);
    int mid = tree[rt].l + tree[rt].r >> 1;
    if(l<=mid)
        update(rt << 1, l, r);
    if(r>mid)
        update(rt << 1 | 1, l, r);
    push_up(rt);
}
Matrix query(int rt,int l,int r){
    if(l<=tree[rt].l&&tree[rt].r<=r)
        return tree[rt].num;
    push_down(rt);
    Matrix ans;
    ans.emplace();
    int mid = tree[rt].l + tree[rt].r >> 1;
    if(l<=mid)
        ans = ans * query(rt << 1, l, r);
    if(r>mid)
        ans = ans * query(rt << 1 | 1, l, r);
    return ans;
}
void output(Matrix ans){
    int ansn = 1;
    for (int i = 0; i < 3;i++)
        ansn *= ans.arr[1][i];
    printf("%d
", ansn);
}
int main(){
    scanf("%d",&T);
    Trans.init(2);
    while(T--){
        scanf("%d%d", &N,&Q);
        scanf("%s", s + 1);
        build(1, 1, N);
        int l, r;
        while(Q--){
            scanf("%d%d%d", &type, &l, &r);
            if(type==1)
                update(1, l, r);
            else
                output(query(1, l, r));
        }
    }
}
原文地址:https://www.cnblogs.com/sherrlock/p/14500360.html